diff options
author | Ben Banfield-Zanin <benbz@matrix.org> | 2021-03-01 10:06:09 +0000 |
---|---|---|
committer | Ben Banfield-Zanin <benbz@matrix.org> | 2021-03-01 10:06:09 +0000 |
commit | b26bee9faf957643cd34c4146b250b0009be205d (patch) | |
tree | a7a7e29f30acb437d010bdf6116c0f2729f21a1b /synapse/handlers | |
parent | Merge remote-tracking branch 'origin/release-v1.26.0' into toml/keycloak_hints (diff) | |
parent | Fixup changelog (diff) | |
download | synapse-toml/keycloak_hints.tar.xz |
Merge remote-tracking branch 'origin/release-v1.28.0' into toml/keycloak_hints github/toml/keycloak_hints toml/keycloak_hints
Diffstat (limited to '')
35 files changed, 1354 insertions, 700 deletions
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 8476256a59..5ecb2da1ac 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING import twisted import twisted.internet.error @@ -22,6 +23,9 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) ACME_REGISTER_FAIL_ERROR = """ @@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC class AcmeHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.reactor = hs.get_reactor() self._acme_domain = hs.config.acme_domain - async def start_listening(self): + async def start_listening(self) -> None: from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug @@ -85,7 +89,7 @@ class AcmeHandler: logger.error(ACME_REGISTER_FAIL_ERROR) raise - async def provision_certificate(self): + async def provision_certificate(self) -> None: logger.warning("Reprovisioning %s", self._acme_domain) @@ -110,5 +114,3 @@ class AcmeHandler: except Exception: logger.exception("Failed saving!") raise - - return True diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index 7294649d71..ae2a9dd9c2 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to imported conditionally. """ import logging +from typing import Dict, Iterable, List import attr +import pem from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from josepy import JWKRSA @@ -36,20 +38,27 @@ from txacme.util import generate_private_key from zope.interface import implementer from twisted.internet import defer +from twisted.internet.interfaces import IReactorTCP from twisted.python.filepath import FilePath from twisted.python.url import URL +from twisted.web.resource import IResource logger = logging.getLogger(__name__) -def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): +def create_issuing_service( + reactor: IReactorTCP, + acme_url: str, + account_key_file: str, + well_known_resource: IResource, +) -> AcmeIssuingService: """Create an ACME issuing service, and attach it to a web Resource Args: reactor: twisted reactor - acme_url (str): URL to use to request certificates - account_key_file (str): where to store the account key - well_known_resource (twisted.web.IResource): web resource for .well-known. + acme_url: URL to use to request certificates + account_key_file: where to store the account key + well_known_resource: web resource for .well-known. we will attach a child resource for "acme-challenge". Returns: @@ -83,18 +92,20 @@ class ErsatzStore: A store that only stores in memory. """ - certs = attr.ib(default=attr.Factory(dict)) + certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) - def store(self, server_name, pem_objects): + def store( + self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] + ) -> defer.Deferred: self.certs[server_name] = [o.as_bytes() for o in pem_objects] return defer.succeed(None) -def load_or_create_client_key(key_file): +def load_or_create_client_key(key_file: str) -> JWKRSA: """Load the ACME account key from a file, creating it if it does not exist. Args: - key_file (str): name of the file to use as the account key + key_file: name of the file to use as the account key """ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't # hardcode the 'client.key' filename diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 37e63da9b1..db68c94c50 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -203,13 +203,11 @@ class AdminHandler(BaseHandler): class ExfiltrationWriter(metaclass=abc.ABCMeta): - """Interface used to specify how to write exported data. - """ + """Interface used to specify how to write exported data.""" @abc.abstractmethod def write_events(self, room_id: str, events: List[EventBase]) -> None: - """Write a batch of events for a room. - """ + """Write a batch of events for a room.""" raise NotImplementedError() @abc.abstractmethod diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 5c6458eb52..deab8ff2d0 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -290,7 +290,9 @@ class ApplicationServicesHandler: if not interested: continue presence_events, _ = await presence_source.get_new_events( - user=user, service=service, from_key=from_key, + user=user, + service=service, + from_key=from_key, ) time_now = self.clock.time_msec() events.extend( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0e98db22b3..9ba9f591d9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi +from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable @@ -119,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier( # Ensure the identifier has a type if "type" not in identifier: raise SynapseError( - 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + 400, + "'identifier' dict has no key 'type'", + errcode=Codes.MISSING_PARAM, ) return identifier @@ -350,7 +353,11 @@ class AuthHandler(BaseHandler): try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, description, get_new_session_data, + flows, + request, + request_body, + description, + get_new_session_data, ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). @@ -378,8 +385,7 @@ class AuthHandler(BaseHandler): return params, session_id async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: - """Get a list of the authentication types this user can use - """ + """Get a list of the authentication types this user can use""" ui_auth_types = set() @@ -567,16 +573,6 @@ class AuthHandler(BaseHandler): session.session_id, login_type, result ) except LoginError as e: - if login_type == LoginType.EMAIL_IDENTITY: - # riot used to have a bug where it would request a new - # validation token (thus sending a new email) each time it - # got a 401 with a 'flows' field. - # (https://github.com/vector-im/vector-web/issues/2447). - # - # Grandfather in the old behaviour for now to avoid - # breaking old riot deployments. - raise - # this step failed. Merge the error dict into the response # so that the client can have another go. errordict = e.error_dict() @@ -732,7 +728,9 @@ class AuthHandler(BaseHandler): } def _auth_dict_for_flows( - self, flows: List[List[str]], session_id: str, + self, + flows: List[List[str]], + session_id: str, ) -> Dict[str, Any]: public_flows = [] for f in flows: @@ -889,7 +887,9 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, login_submission: Dict[str, Any], ratelimit: bool = False, + self, + login_submission: Dict[str, Any], + ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Authenticates the user for the /login API @@ -1032,7 +1032,9 @@ class AuthHandler(BaseHandler): raise async def _validate_userid_login( - self, username: str, login_submission: Dict[str, Any], + self, + username: str, + login_submission: Dict[str, Any], ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Helper for validate_login @@ -1387,7 +1389,9 @@ class AuthHandler(BaseHandler): ) return self._sso_auth_confirm_template.render( - description=session.description, redirect_url=redirect_url, + description=session.description, + redirect_url=redirect_url, + idp=sso_auth_provider, ) async def complete_sso_login( @@ -1396,6 +1400,7 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1406,6 +1411,8 @@ class AuthHandler(BaseHandler): process. extra_attributes: Extra attributes which will be passed to the client during successful login. Must be JSON serializable. + new_user: True if we should use wording appropriate to a user who has just + registered. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1414,8 +1421,17 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return + profile = await self.store.get_profileinfo( + UserID.from_string(registered_user_id).localpart + ) + self._complete_sso_login( - registered_user_id, request, client_redirect_url, extra_attributes + registered_user_id, + request, + client_redirect_url, + extra_attributes, + new_user=new_user, + user_profile_data=profile, ) def _complete_sso_login( @@ -1424,18 +1440,25 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, + user_profile_data: Optional[ProfileInfo] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + + if user_profile_data is None: + user_profile_data = ProfileInfo(None, None) + # Store any extra attributes which will be passed in the login response. # Note that this is per-user so it may overwrite a previous value, this # is considered OK since the newest SSO attributes should be most valid. if extra_attributes: self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( - self._clock.time_msec(), extra_attributes, + self._clock.time_msec(), + extra_attributes, ) # Create a login token @@ -1461,12 +1484,27 @@ class AuthHandler(BaseHandler): # Remove the query parameters from the redirect URL to get a shorter version of # it. This is only to display a human-readable URL in the template, but not the # URL we redirect users to. - redirect_url_no_params = client_redirect_url.split("?")[0] + url_parts = urllib.parse.urlsplit(client_redirect_url) + + if url_parts.scheme == "https": + # for an https uri, just show the netloc (ie, the hostname. Specifically, + # the bit between "//" and "/"; this includes any potential + # "username:password@" prefix.) + display_url = url_parts.netloc + else: + # for other uris, strip the query-params (including the login token) and + # fragment. + display_url = urllib.parse.urlunsplit( + (url_parts.scheme, url_parts.netloc, url_parts.path, "", "") + ) html = self._sso_redirect_confirm_template.render( - display_url=redirect_url_no_params, + display_url=display_url, redirect_url=redirect_url, server_name=self._server_name, + new_user=new_user, + user_id=registered_user_id, + user_profile=user_profile_data, ) respond_with_html(request, 200, html) @@ -1676,5 +1714,9 @@ class PasswordProvider: # This might return an awaitable, if it does block the log out # until it completes. await maybe_awaitable( - g(user_id=user_id, device_id=device_id, access_token=access_token,) + g( + user_id=user_id, + device_id=device_id, + access_token=access_token, + ) ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 0f342c607b..04972f9cf0 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from xml.etree import ElementTree as ET import attr @@ -33,8 +33,7 @@ logger = logging.getLogger(__name__) class CasError(Exception): - """Used to catch errors when validating the CAS ticket. - """ + """Used to catch errors when validating the CAS ticket.""" def __init__(self, error, error_description=None): self.error = error @@ -49,7 +48,7 @@ class CasError(Exception): @attr.s(slots=True, frozen=True) class CasResponse: username = attr.ib(type=str) - attributes = attr.ib(type=Dict[str, Optional[str]]) + attributes = attr.ib(type=Dict[str, List[Optional[str]]]) class CasHandler: @@ -80,9 +79,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" - # we do not currently support icons for CAS auth, but this is required by + # we do not currently support brands/icons for CAS auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None self._sso_handler = hs.get_sso_handler() @@ -99,9 +99,8 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s%s?%s" % ( + return "%s?%s" % ( self._cas_service_url, - "/_matrix/client/r0/login/cas/ticket", urllib.parse.urlencode(args), ) @@ -172,7 +171,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} + attributes = {} # type: Dict[str, List[Optional[str]]] for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -185,7 +184,7 @@ class CasHandler: tag = attribute.tag if "}" in tag: tag = tag.split("}")[1] - attributes[tag] = attribute.text + attributes.setdefault(tag, []).append(attribute.text) # Ensure a user was found. if user is None: @@ -299,36 +298,20 @@ class CasHandler: # first check if we're doing a UIA if session: return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, cas_response.username, session, request, + self.idp_id, + cas_response.username, + session, + request, ) # otherwise, we're handling a login request. # Ensure that the attributes of the logged in user meet the required # attributes. - for required_attribute, required_value in self._cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in cas_response.attributes: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return - - # Also need to check value - if required_value is not None: - actual_value = cas_response.attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return + if not self._sso_handler.check_required_attributes( + request, cas_response.attributes, self._cas_required_attributes + ): + return # Call the mapper to register/login the user @@ -375,9 +358,10 @@ class CasHandler: if failures: raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + # Arbitrarily use the first attribute found. display_name = cas_response.attributes.get( - self._cas_displayname_attribute, None - ) + self._cas_displayname_attribute, [None] + )[0] return UserAttributes(localpart=localpart, display_name=display_name) @@ -387,7 +371,8 @@ class CasHandler: user_id = UserID(localpart, self._hostname).to_string() logger.debug( - "Looking for existing account based on mapped %s", user_id, + "Looking for existing account based on mapped %s", + user_id, ) users = await self._store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index c4a3b26a84..94f3f3163f 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -196,8 +196,7 @@ class DeactivateAccountHandler(BaseHandler): run_as_background_process("user_parter_loop", self._user_parter_loop) async def _user_parter_loop(self) -> None: - """Loop that parts deactivated users from rooms - """ + """Loop that parts deactivated users from rooms""" self._user_parter_running = True logger.info("Starting user parter") try: @@ -214,8 +213,7 @@ class DeactivateAccountHandler(BaseHandler): self._user_parter_running = False async def _part_user(self, user_id: str) -> None: - """Causes the given user_id to leave all the rooms they're joined to - """ + """Causes the given user_id to leave all the rooms they're joined to""" user = UserID.from_string(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index debb1b4f29..df3cdc8fba 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: + async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ Retrieve the given user's devices @@ -85,8 +85,8 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: - """ Retrieve the given device + async def get_device(self, user_id: str, device_id: str) -> JsonDict: + """Retrieve the given device Args: user_id: The user to get the device from @@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace async def delete_device(self, user_id: str, device_id: str) -> None: - """ Delete the given device + """Delete the given device Args: user_id: The user to delete the device from. @@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, device_ids) async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: - """ Delete several devices + """Delete several devices Args: user_id: The user to delete devices from. @@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler): await self.notify_device_update(user_id, device_ids) async def update_device(self, user_id: str, device_id: str, content: dict) -> None: - """ Update the given device + """Update the given device Args: user_id: The user to update devices of. @@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler): device id of the dehydrated device """ device_id = await self.check_device_registered( - user_id, None, initial_device_display_name, + user_id, + None, + initial_device_display_name, ) old_device_id = await self.store.store_dehydrated_device( user_id, device_id, device_data @@ -598,7 +600,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] + device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -803,7 +805,8 @@ class DeviceListUpdater: try: # Try to resync the current user's devices list. result = await self.user_device_resync( - user_id=user_id, mark_failed_as_stale=False, + user_id=user_id, + mark_failed_as_stale=False, ) # user_device_resync only returns a result if it managed to @@ -813,14 +816,17 @@ class DeviceListUpdater: # self.store.update_remote_device_list_cache). if result: logger.debug( - "Successfully resynced the device list for %s", user_id, + "Successfully resynced the device list for %s", + user_id, ) except Exception as e: # If there was an issue resyncing this user, e.g. if the remote # server sent a malformed result, just log the error instead of # aborting all the subsequent resyncs. logger.debug( - "Could not resync the device list for %s: %s", user_id, e, + "Could not resync the device list for %s: %s", + user_id, + e, ) finally: # Allow future calls to retry resyncinc out of sync device lists. @@ -855,7 +861,9 @@ class DeviceListUpdater: return None except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to handle device list update for %s: %s", user_id, e, + "Failed to handle device list update for %s: %s", + user_id, + e, ) if mark_failed_as_stale: @@ -931,7 +939,9 @@ class DeviceListUpdater: # Handle cross-signing keys. cross_signing_device_ids = await self.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + cross_signing_device_ids @@ -946,8 +956,8 @@ class DeviceListUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[Dict[str, Any]], - self_signing_key: Optional[Dict[str, Any]], + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], ) -> List[str]: """Process the given new master and self-signing key for the given remote user. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 0c7737e09d..1aa7d803b5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -62,7 +62,8 @@ class DeviceMessageHandler: ) else: hs.get_federation_registry().register_instances_for_edu( - "m.direct_to_device", hs.config.worker.writers.to_device, + "m.direct_to_device", + hs.config.worker.writers.to_device, ) # The handler to call when we think a user's device list might be out of @@ -73,8 +74,8 @@ class DeviceMessageHandler: hs.get_device_handler().device_list_updater.user_device_resync ) else: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 929752150d..9a946a3cfe 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( + JsonDict, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class E2eKeysHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -57,8 +61,8 @@ class E2eKeysHandler: self._is_master = hs.config.worker_app is None if not self._is_master: - self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) else: # Only register this edu handler on master as it requires writing @@ -78,8 +82,10 @@ class E2eKeysHandler: ) @trace - async def query_devices(self, query_body, timeout, from_user_id): - """ Handle a device key query from a client + async def query_devices( + self, query_body: JsonDict, timeout: int, from_user_id: str + ) -> JsonDict: + """Handle a device key query from a client { "device_keys": { @@ -98,12 +104,14 @@ class E2eKeysHandler: } Args: - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] # separate users by domain. # make a map from domain to user_id to device_ids @@ -121,7 +129,8 @@ class E2eKeysHandler: set_tag("remote_key_query", remote_queries) # First get local devices. - failures = {} + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -135,9 +144,10 @@ class E2eKeysHandler: ) # Now attempt to get any remote devices from our local cache. - remote_queries_not_in_cache = {} + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] if remote_queries: - query_list = [] + query_list = [] # type: List[Tuple[str, Optional[str]]] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) @@ -284,15 +294,15 @@ class E2eKeysHandler: return ret async def get_cross_signing_keys_from_cache( - self, query, from_user_id + self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: - query (Iterable[string]) an iterable of user IDs. A dict whose keys + query: an iterable of user IDs. A dict whose keys are user IDs satisfies this, so the query format used for query_devices can be used here. - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. @@ -315,14 +325,12 @@ class E2eKeysHandler: if "self_signing" in user_info: self_signing_keys[user_id] = user_info["self_signing"] - if ( - from_user_id in keys - and keys[from_user_id] is not None - and "user_signing" in keys[from_user_id] - ): - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id: + from_user_key = keys.get(from_user_id) + if from_user_key and "user_signing" in from_user_key: + user_signing_keys[from_user_id] = from_user_key["user_signing"] return { "master_keys": master_keys, @@ -344,9 +352,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] + local_query = [] # type: List[Tuple[str, Optional[str]]] - result_dict = {} + result_dict = {} # type: Dict[str, Dict[str, dict]] for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -380,10 +388,13 @@ class E2eKeysHandler: log_kv(results) return result_dict - async def on_federation_query_client_keys(self, query_body): - """ Handle a device key query from a federated server - """ - device_keys_query = query_body.get("device_keys", {}) + async def on_federation_query_client_keys( + self, query_body: Dict[str, Dict[str, Optional[List[str]]]] + ) -> JsonDict: + """Handle a device key query from a federated server""" + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Optional[List[str]]] res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -397,31 +408,34 @@ class E2eKeysHandler: return ret @trace - async def claim_one_time_keys(self, query, timeout): - local_query = [] - remote_queries = {} + async def claim_one_time_keys( + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + ) -> JsonDict: + local_query = [] # type: List[Tuple[str, str, str]] + remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] - for user_id, device_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in device_keys.items(): + for device_id, algorithm in one_time_keys.items(): local_query.append((user_id, device_id, algorithm)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_keys + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) results = await self.store.claim_e2e_one_time_keys(local_query) - json_result = {} - failures = {} + # A map of user ID -> device ID -> key ID -> key. + json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + failures = {} # type: Dict[str, JsonDict] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_bytes) + key_id: json_decoder.decode(json_str) } @trace @@ -468,7 +482,9 @@ class E2eKeysHandler: return {"one_time_keys": json_result, "failures": failures} @tag_args - async def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: time_now = self.clock.time_msec() @@ -543,8 +559,8 @@ class E2eKeysHandler: return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( - self, user_id, device_id, time_now, one_time_keys - ): + self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict + ) -> None: logger.info( "Adding one_time_keys %r for device %r for user %r at %d", one_time_keys.keys(), @@ -585,12 +601,14 @@ class E2eKeysHandler: log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - async def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user( + self, user_id: str, keys: JsonDict + ) -> JsonDict: """Upload signing keys for cross-signing Args: - user_id (string): the user uploading the keys - keys (dict[string, dict]): the signing keys + user_id: the user uploading the keys + keys: the signing keys """ # if a master key is uploaded, then check it. Otherwise, load the @@ -667,16 +685,17 @@ class E2eKeysHandler: return {} - async def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys( + self, user_id: str, signatures: JsonDict + ) -> JsonDict: """Upload device signatures for cross-signing Args: - user_id (string): the user uploading the signatures - signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys. This is the submission from the user; an - exception will be raised if it is malformed. + user_id: the user uploading the signatures + signatures: map of users to devices to signed keys. This is the submission + from the user; an exception will be raised if it is malformed. Returns: - dict: response to be sent back to the client. The response will have + The response to be sent back to the client. The response will have a "failures" key, which will be a dict mapping users to devices to errors for the signatures that failed. Raises: @@ -719,7 +738,9 @@ class E2eKeysHandler: return {"failures": failures} - async def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures( + self, user_id: str, signatures: JsonDict + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -731,15 +752,14 @@ class E2eKeysHandler: signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure - reasons + A tuple of a list of signatures to store, and a map of users to + devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -834,19 +854,24 @@ class E2eKeysHandler: return signature_list, failures def _check_master_key_signature( - self, user_id, master_key_id, signed_master_key, stored_master_key, devices - ): + self, + user_id: str, + master_key_id: str, + signed_master_key: JsonDict, + stored_master_key: JsonDict, + devices: Dict[str, Dict[str, JsonDict]], + ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user whose master key is being checked - master_key_id (string): the ID of the user's master key - signed_master_key (dict): the user's signed master key that was uploaded - stored_master_key (dict): our previously-stored copy of the user's master key - devices (iterable(dict)): the user's devices + user_id: the user whose master key is being checked + master_key_id: the ID of the user's master key + signed_master_key: the user's signed master key that was uploaded + stored_master_key: our previously-stored copy of the user's master key + devices: the user's devices Returns: - list[SignatureListItem]: a list of signatures to store + A list of signatures to store Raises: SynapseError: if a signature is invalid @@ -877,25 +902,26 @@ class E2eKeysHandler: return master_key_signature_list - async def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures( + self, user_id: str, signatures: Dict[str, dict] + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id: the user uploading the keys + signatures: map of users to devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure + A list of signatures to store, and a map of users to devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -983,7 +1009,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None - ): + ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -997,8 +1023,7 @@ class E2eKeysHandler: This affects what signatures are fetched. Returns: - dict, str, VerifyKey: the raw key data, the key ID, and the - signedjson verify key + The raw key data, the key ID, and the signedjson verify key Raises: NotFoundError: if the key is not found @@ -1039,7 +1064,9 @@ class E2eKeysHandler: return key, key_id, verify_key async def _retrieve_cross_signing_keys_for_remote_user( - self, user: UserID, desired_key_type: str, + self, + user: UserID, + desired_key_type: str, ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database @@ -1135,16 +1162,18 @@ class E2eKeysHandler: return desired_key, desired_key_id, desired_verify_key -def _check_cross_signing_key(key, user_id, key_type, signing_key=None): +def _check_cross_signing_key( + key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None +) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. Args: - key (dict): the key data to verify - user_id (str): the user whose key is being checked - key_type (str): the type of key that the key should be - signing_key (VerifyKey): (optional) the signing key that the key should - be signed with. If omitted, signatures will not be checked. + key: the key data to verify + user_id: the user whose key is being checked + key_type: the type of key that the key should be + signing_key: the signing key that the key should be signed with. If + omitted, signatures will not be checked. """ if ( key.get("user_id") != user_id @@ -1162,16 +1191,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) -def _check_device_signature(user_id, verify_key, signed_device, stored_device): +def _check_device_signature( + user_id: str, + verify_key: VerifyKey, + signed_device: JsonDict, + stored_device: JsonDict, +) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an exception if an error is detected. Args: - user_id (str): the user ID whose signature is being checked - verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the uploaded signed device data - stored_device (dict): our previously stored copy of the device + user_id: the user ID whose signature is being checked + verify_key: the key to verify the device with + signed_device: the uploaded signed device data + stored_device: our previously stored copy of the device Raises: SynapseError: if the signature was invalid or the sent device is not the @@ -1201,7 +1235,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) -def _exception_to_failure(e): +def _exception_to_failure(e: Exception) -> JsonDict: if isinstance(e, SynapseError): return {"status": e.code, "errcode": e.errcode, "message": str(e)} @@ -1218,7 +1252,7 @@ def _exception_to_failure(e): return {"status": 503, "message": str(e)} -def _one_time_keys_match(old_key_json, new_key): +def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly @@ -1236,19 +1270,18 @@ def _one_time_keys_match(old_key_json, new_key): @attr.s(slots=True) class SignatureListItem: - """An item in the signature list as used by upload_signatures_for_device_keys. - """ + """An item in the signature list as used by upload_signatures_for_device_keys.""" - signing_key_id = attr.ib() - target_user_id = attr.ib() - target_device_id = attr.ib() - signature = attr.ib() + signing_key_id = attr.ib(type=str) + target_user_id = attr.ib(type=str) + target_device_id = attr.ib(type=str) + signature = attr.ib(type=JsonDict) class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs, e2e_keys_handler): + def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater: iterable=True, ) - async def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. Args: - origin (string): the server that sent the EDU - edu_content (dict): the contents of the EDU + origin: the server that sent the EDU + edu_content: the contents of the EDU """ user_id = edu_content.pop("user_id") @@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater: await self._handle_signing_key_updates(user_id) - async def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id: str) -> None: """Actually handle pending updates. Args: - user_id (string): the user whose updates we are processing + user_id: the user whose updates we are processing """ device_handler = self.e2e_keys_handler.device_handler @@ -1315,13 +1350,17 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] + device_ids = [] # type: List[str] logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = await device_list_updater.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + new_device_ids = ( + await device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, + ) ) device_ids = device_ids + new_device_ids diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f01b090772..622cae23be 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import ( Codes, @@ -24,8 +25,12 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class E2eRoomKeysHandler: The actual payload of the encrypted keys is completely opaque to the handler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # Used to lock whenever a client is uploading key data. This prevents collisions @@ -48,21 +53,27 @@ class E2eRoomKeysHandler: self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - async def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> List[JsonDict]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. Args: - user_id(str): the user whose keys we're getting - version(str): the version ID of the backup we're getting keys from - room_id(string): room ID to get keys for, for None to get keys for all rooms - session_id(string): session ID to get keys for, for None to get keys for all + user_id: the user whose keys we're getting + version: the version ID of the backup we're getting keys from + room_id: room ID to get keys for, for None to get keys for all rooms + session_id: session ID to get keys for, for None to get keys for all sessions Raises: NotFoundError: if the backup version does not exist Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -86,17 +97,23 @@ class E2eRoomKeysHandler: return results @trace - async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. Args: - user_id(str): the user whose backup we're deleting - version(str): the version ID of the backup we're deleting - room_id(string): room ID to delete keys for, for None to delete keys for all + user_id: the user whose backup we're deleting + version: the version ID of the backup we're deleting + room_id: room ID to delete keys for, for None to delete keys for all rooms - session_id(string): session ID to delete keys for, for None to delete keys + session_id: session ID to delete keys for, for None to delete keys for all sessions Raises: NotFoundError: if the backup version does not exist @@ -128,15 +145,17 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @trace - async def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys( + self, user_id: str, version: str, room_keys: JsonDict + ) -> JsonDict: """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_keys(dict): a nested dict describing the room_keys we're setting: + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_keys: a nested dict describing the room_keys we're setting: { "rooms": { @@ -254,14 +273,16 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @staticmethod - def _should_replace_room_key(current_room_key, room_key): + def _should_replace_room_key( + current_room_key: Optional[JsonDict], room_key: JsonDict + ) -> bool: """ Determine whether to replace a given current_room_key (if any) with a newly uploaded room_key backup Args: - current_room_key (dict): Optional, the current room_key dict if any - room_key (dict): The new room_key dict which may or may not be fit to + current_room_key: Optional, the current room_key dict if any + room_key : The new room_key dict which may or may not be fit to replace the current_room_key Returns: @@ -286,14 +307,14 @@ class E2eRoomKeysHandler: return True @trace - async def create_version(self, user_id, version_info): + async def create_version(self, user_id: str, version_info: JsonDict) -> str: """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. Args: - user_id(str): the user whose backup version we're creating - version_info(dict): metadata about the new version being created + user_id: the user whose backup version we're creating + version_info: metadata about the new version being created { "algorithm": "m.megolm_backup.v1", @@ -301,7 +322,7 @@ class E2eRoomKeysHandler: } Returns: - A deferred of a string that gives the new version number. + The new version number. """ # TODO: Validate the JSON to make sure it has the right keys. @@ -313,17 +334,19 @@ class E2eRoomKeysHandler: ) return new_version - async def get_version_info(self, user_id, version=None): + async def get_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're querying - version(str): Optional; if None gives the most recent version + user_id: the user whose current backup version we're querying + version: Optional; if None gives the most recent version otherwise a historical one. Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of a info dict that gives the info about the new version. + A info dict that gives the info about the new version. { "version": "1234", @@ -346,7 +369,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id, version=None): + async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: @@ -366,17 +389,19 @@ class E2eRoomKeysHandler: raise @trace - async def update_version(self, user_id, version, version_info): + async def update_version( + self, user_id: str, version: str, version_info: JsonDict + ) -> JsonDict: """Update the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're updating - version(str): the backup version we're updating - version_info(dict): the new information about the backup + user_id: the user whose current backup version we're updating + version: the backup version we're updating + version_info: the new information about the backup Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of an empty dict. + An empty dict. """ if "version" not in version_info: version_info["version"] = version diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 539b4fc32e..3e23f82cf7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler): room_id: Optional[str] = None, is_guest: bool = False, ) -> JsonDict: - """Fetches the events stream for a given user. - """ + """Fetches the events stream for a given user.""" if room_id: blocked = await self.store.is_room_blocked(room_id) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd8de8696d..2ead626a4d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -111,13 +111,13 @@ class _NewEventInfo: class FederationHandler(BaseHandler): """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + Responsible for: + a) handling received Pdus before handing them on as Events to the rest + of the homeserver (including auth and state conflict resolutions) + b) converting events that were produced by local clients that may need + to be sent to remote homeservers. + c) doing the necessary dances to invite remote users and join remote + rooms. """ def __init__(self, hs: "HomeServer"): @@ -150,11 +150,11 @@ class FederationHandler(BaseHandler): ) if hs.config.worker_app: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) - self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( - hs + self._maybe_store_room_on_outlier_membership = ( + ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: self._device_list_updater = hs.get_device_handler().device_list_updater @@ -172,7 +172,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: - """ Process a PDU received via a federation /send/ transaction, or + """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: @@ -368,7 +368,8 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "Requesting state at missing prev_event %s", event_id, + "Requesting state at missing prev_event %s", + event_id, ) with nested_logging_context(p): @@ -388,12 +389,14 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), + state_map = ( + await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) ) # We need to give _process_received_pdu the actual state events @@ -687,9 +690,12 @@ class FederationHandler(BaseHandler): return fetched_events async def _process_received_pdu( - self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], ): - """ Called when we have a new pdu. We need to do auth checks and put it + """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. Args: @@ -801,7 +807,7 @@ class FederationHandler(BaseHandler): @log_function async def backfill(self, dest, room_id, limit, extremities): - """ Trigger a backfill request to `dest` for the given `room_id` + """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side has no new events to offer, this will return an empty list. @@ -1204,11 +1210,16 @@ class FederationHandler(BaseHandler): with nested_logging_context(event_id): try: event = await self.federation_client.get_pdu( - [destination], event_id, room_version, outlier=True, + [destination], + event_id, + room_version, + outlier=True, ) if event is None: logger.warning( - "Server %s didn't return event %s", destination, event_id, + "Server %s didn't return event %s", + destination, + event_id, ) return @@ -1235,7 +1246,8 @@ class FederationHandler(BaseHandler): if aid not in event_map ] persisted_events = await self.store.get_events( - auth_events, allow_rejected=True, + auth_events, + allow_rejected=True, ) event_infos = [] @@ -1251,7 +1263,9 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, room_id, event_infos, + destination, + room_id, + event_infos, ) def _sanity_check_event(self, ev): @@ -1287,7 +1301,7 @@ class FederationHandler(BaseHandler): raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): - """ Sends the invite to the remote server for signing. + """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. """ @@ -1310,7 +1324,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict ) -> Tuple[str, int]: - """ Attempts to join the `joinee` to the room `room_id` via the + """Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial @@ -1354,8 +1368,6 @@ class FederationHandler(BaseHandler): await self._clean_room_for_join(room_id) - handled_events = set() - try: # Try the host we successfully got a response to /make_join/ # request first. @@ -1375,10 +1387,6 @@ class FederationHandler(BaseHandler): auth_chain = ret["auth_chain"] auth_chain.sort(key=lambda e: e.depth) - handled_events.update([s.event_id for s in state]) - handled_events.update([a.event_id for a in auth_chain]) - handled_events.add(event.event_id) - logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join state: %s", state) @@ -1394,7 +1402,8 @@ class FederationHandler(BaseHandler): # so we can rely on it now. # await self.store.upsert_room_on_join( - room_id=room_id, room_version=room_version_obj, + room_id=room_id, + room_version=room_version_obj, ) max_stream_id = await self._persist_auth_tree( @@ -1464,7 +1473,7 @@ class FederationHandler(BaseHandler): async def on_make_join_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_join/ request, so we create a partial + """We've received a /make_join/ request, so we create a partial join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1489,7 +1498,8 @@ class FederationHandler(BaseHandler): is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( - "Got /make_join request for room %s we are no longer in", room_id, + "Got /make_join request for room %s we are no longer in", + room_id, ) raise NotFoundError("Not an active room on this server") @@ -1523,7 +1533,7 @@ class FederationHandler(BaseHandler): return event async def on_send_join_request(self, origin, pdu): - """ We have received a join event for a room. Fully process it and + """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ event = pdu @@ -1579,7 +1589,7 @@ class FederationHandler(BaseHandler): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion ): - """ We've got an invite event. Process and persist it. Sign it. + """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. """ @@ -1617,6 +1627,12 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + # We don't rate limit based on room ID, as that should be done by + # sending server. + member_handler.ratelimit_invite(None, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). @@ -1700,7 +1716,7 @@ class FederationHandler(BaseHandler): async def on_make_leave_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_leave/ request, so we create a partial + """We've received a /make_leave/ request, so we create a partial leave event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1776,8 +1792,7 @@ class FederationHandler(BaseHandler): return None async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) @@ -1803,8 +1818,7 @@ class FederationHandler(BaseHandler): return [] async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2010,7 +2024,11 @@ class FederationHandler(BaseHandler): for e_id in missing_auth_events: m_ev = await self.federation_client.get_pdu( - [origin], e_id, room_version=room_version, outlier=True, timeout=10000, + [origin], + e_id, + room_version=room_version, + outlier=True, + timeout=10000, ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -2093,6 +2111,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( @@ -2155,7 +2178,9 @@ class FederationHandler(BaseHandler): ) logger.debug( - "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + "Doing soft-fail check for %s: state %s", + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state @@ -2508,7 +2533,7 @@ class FederationHandler(BaseHandler): async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: - """ Given a local and remote auth chain, find the differences. This + """Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth Params: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index df29edeb83..bfb95e3eee 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -15,9 +15,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Dict, Iterable, List, Set from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, get_domain_from_id +from synapse.types import GroupID, JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ def _create_rerouter(func_name): class GroupsLocalWorkerHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.room_list_handler = hs.get_room_list_handler() @@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler: get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") - async def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the group summary for a group. If the group is remote we check that the users have valid attestations. @@ -137,14 +143,14 @@ class GroupsLocalWorkerHandler: return res - async def get_users_in_group(self, group_id, requester_user_id): - """Get users in a group - """ + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get users in a group""" if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_users_in_group( + return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) - return res group_server_name = get_domain_from_id(group_id) @@ -178,11 +184,11 @@ class GroupsLocalWorkerHandler: return res - async def get_joined_groups(self, user_id): + async def get_joined_groups(self, user_id: str) -> JsonDict: group_ids = await self.store.get_joined_groups(user_id) return {"groups": group_ids} - async def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: if self.hs.is_mine_id(user_id): result = await self.store.get_publicised_groups_for_user(user_id) @@ -206,8 +212,10 @@ class GroupsLocalWorkerHandler: # TODO: Verify attestations return {"groups": result} - async def bulk_get_publicised_groups(self, user_ids, proxy=True): - destinations = {} + async def bulk_get_publicised_groups( + self, user_ids: Iterable[str], proxy: bool = True + ) -> JsonDict: + destinations = {} # type: Dict[str, Set[str]] local_users = set() for user_id in user_ids: @@ -220,7 +228,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] + failed_results = [] # type: List[str] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( @@ -242,7 +250,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # Ensure attestations get renewed @@ -271,9 +279,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - async def create_group(self, group_id, user_id, content): - """Create a group - """ + async def create_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: + """Create a group""" logger.info("Asking to create group with ID: %r", group_id) @@ -284,27 +293,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = None remote_attestation = None else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - content["user_profile"] = await self.profile_handler.get_profile(user_id) - - try: - res = await self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) + raise SynapseError(400, "Unable to create remote groups") is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -320,9 +309,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def join_group(self, group_id, user_id, content): - """Request to join a group - """ + async def join_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: + """Request to join a group""" if self.is_mine_id(group_id): await self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None @@ -365,9 +355,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def accept_invite(self, group_id, user_id, content): - """Accept an invite to a group - """ + async def accept_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: + """Accept an invite to a group""" if self.is_mine_id(group_id): await self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None @@ -410,9 +401,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def invite(self, group_id, user_id, requester_user_id, config): - """Invite a user to a group - """ + async def invite( + self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict + ) -> JsonDict: + """Invite a user to a group""" content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = await self.groups_server_handler.invite_to_group( @@ -434,9 +426,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def on_invite(self, group_id, user_id, content): - """One of our users were invited to a group - """ + async def on_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: + """One of our users were invited to a group""" # TODO: Support auto join and rejection if not self.is_mine_id(user_id): @@ -465,10 +458,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} async def remove_user_from_group( - self, group_id, user_id, requester_user_id, content - ): - """Remove a user from a group - """ + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """Remove a user from a group""" if user_id == requester_user_id: token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" @@ -499,9 +491,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def user_removed_from_group(self, group_id, user_id, content): - """One of our users was removed/kicked from a group - """ + async def user_removed_from_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> None: + """One of our users was removed/kicked from a group""" # TODO: Check if user in group token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..5f346f6d6d 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,35 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, + request: SynapseRequest, + medium: str, + address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: @@ -476,6 +507,10 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") + # It is already checked that public_baseurl is configured since this code + # should only be used if account_threepid_delegate_msisdn is true. + assert self.hs.config.public_baseurl + # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index fbd8df9dcc..78c3e5a10b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler): joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] receipt = await self.store.get_linearized_receipts_for_rooms( - joined_rooms, to_key=int(now_token.receipt_key), + joined_rooms, + to_key=int(now_token.receipt_key), ) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler): self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: - room_end_token = RoomStreamToken(None, event.stream_ordering,) + room_end_token = RoomStreamToken( + None, + event.stream_ordering, + ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] ) @@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler): membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True, + room_id, + user_id, + allow_departed_users=True, ) is_peeking = member_event_id is None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dfeab09cd..c03f6c997b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -65,8 +65,7 @@ logger = logging.getLogger(__name__) class MessageHandler: - """Contains some read only APIs to get state about a room - """ + """Contains some read only APIs to get state about a room""" def __init__(self, hs): self.auth = hs.get_auth() @@ -88,9 +87,13 @@ class MessageHandler: ) async def get_room_data( - self, user_id: str, room_id: str, event_type: str, state_key: str, + self, + user_id: str, + room_id: str, + event_type: str, + state_key: str, ) -> dict: - """ Get data from a room. + """Get data from a room. Args: user_id @@ -174,7 +177,10 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False + self.storage, + user_id, + last_events, + filter_send_to_client=False, ) event = last_events[0] @@ -432,6 +438,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -569,7 +577,7 @@ class EventCreationHandler: async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester ) -> bool: - """"Determine if an event to be sent is exempt from having to consent + """ "Determine if an event to be sent is exempt from having to consent to the privacy policy Args: @@ -791,9 +799,10 @@ class EventCreationHandler: """ if prev_event_ids is not None: - assert len(prev_event_ids) <= 10, ( - "Attempting to create an event with %i prev_events" - % (len(prev_event_ids),) + assert ( + len(prev_event_ids) <= 10 + ), "Attempting to create an event with %i prev_events" % ( + len(prev_event_ids), ) else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) @@ -819,7 +828,8 @@ class EventCreationHandler: ) if not third_party_result: logger.info( - "Event %s forbidden by third-party rules", event, + "Event %s forbidden by third-party rules", + event, ) raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN @@ -939,6 +949,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +990,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. + + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: @@ -1274,7 +1324,11 @@ class EventCreationHandler: # Since this is a dummy-event it is OK if it is sent by a # shadow-banned user. await self.handle_new_client_event( - requester, event, context, ratelimit=False, ignore_shadow_ban=True, + requester, + event, + context, + ratelimit=False, + ignore_shadow_ban=True, ) return True except AuthError: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 81cb2ffc6b..f73cbe2af3 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -41,13 +41,33 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -SESSION_COOKIE_NAME = b"oidc_session" +# we want the cookie to be returned to us even when the request is the POSTed +# result of a form on another domain, as is used with `response_mode=form_post`. +# +# Modern browsers will not do so unless we set SameSite=None; however *older* +# browsers (including all versions of Safari on iOS 12?) don't support +# SameSite=None, and interpret it as SameSite=Strict: +# https://bugs.webkit.org/show_bug.cgi?id=198181 +# +# As a rather painful workaround, we set *two* cookies, one with SameSite=None +# and one with no SameSite, in the hope that at least one of them will get +# back to us. +# +# Secure is necessary for SameSite=None (and, empirically, also breaks things +# on iOS 12.) +# +# Here we have the names of the cookies, and the options we use to set them. +_SESSION_COOKIES = [ + (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"), + (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"), +] #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. @@ -72,8 +92,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]}) class OidcHandler: - """Handles requests related to the OpenID Connect login flow. - """ + """Handles requests related to the OpenID Connect login flow.""" def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() @@ -102,7 +121,7 @@ class OidcHandler: ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. Instead, we call @@ -123,7 +142,6 @@ class OidcHandler: Args: request: the incoming request from the browser. """ - # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: @@ -137,8 +155,12 @@ class OidcHandler: # either the provider misbehaving or Synapse being misconfigured. # The only exception of that is "access_denied", where the user # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) + logger.log( + logging.INFO if error == "access_denied" else logging.ERROR, + "Received OIDC callback with error: %s %s", + error, + description, + ) self._sso_handler.render_error(request, error, description) return @@ -146,30 +168,37 @@ class OidcHandler: # otherwise, it is presumably a successful response. see: # https://tools.ietf.org/html/rfc6749#section-4.1.2 - # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] - if session is None: - logger.info("No session cookie found") + # Fetch the session cookie. See the comments on SESSION_COOKIES for why there + # are two. + + for cookie_name, _ in _SESSION_COOKIES: + session = request.getCookie(cookie_name) # type: Optional[bytes] + if session is not None: + break + else: + logger.info("Received OIDC callback, with no session cookie") self._sso_handler.render_error( request, "missing_session", "No session cookie found" ) return - # Remove the cookie. There is a good chance that if the callback failed + # Remove the cookies. There is a good chance that if the callback failed # once, it will fail next time and the code will already be exchanged. - # Removing it early avoids spamming the provider with token requests. - request.addCookie( - SESSION_COOKIE_NAME, - b"", - path="/_synapse/oidc", - expires="Thu, Jan 01 1970 00:00:00 UTC", - httpOnly=True, - sameSite="lax", - ) + # Removing the cookies early avoids spamming the provider with token requests. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s" + % (cookie_name, options) + ) # Check for the state query parameter if b"state" not in request.args: - logger.info("State parameter is missing") + logger.info("Received OIDC callback, with no state parameter") self._sso_handler.render_error( request, "invalid_request", "State parameter is missing" ) @@ -183,14 +212,16 @@ class OidcHandler: session, state ) except (MacaroonDeserializationException, ValueError) as e: - logger.exception("Invalid session") + logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") + logger.exception("Could not verify session for OIDC callback") self._sso_handler.render_error(request, "mismatching_session", str(e)) return + logger.info("Received OIDC callback for IdP %s", session_data.idp_id) + oidc_provider = self._providers.get(session_data.idp_id) if not oidc_provider: logger.error("OIDC session uses unknown IdP %r", oidc_provider) @@ -210,8 +241,7 @@ class OidcHandler: class OidcError(Exception): - """Used to catch errors when calling the token_endpoint - """ + """Used to catch errors when calling the token_endpoint""" def __init__(self, error, error_description=None): self.error = error @@ -240,22 +270,27 @@ class OidcProvider: self._token_generator = token_generator + self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - provider.client_id, provider.client_secret, provider.client_auth_method, + provider.client_id, + provider.client_secret, + provider.client_auth_method, ) # type: ClientAuth self._client_auth_method = provider.client_auth_method - self._provider_metadata = OpenIDProviderMetadata( - issuer=provider.issuer, - authorization_endpoint=provider.authorization_endpoint, - token_endpoint=provider.token_endpoint, - userinfo_endpoint=provider.userinfo_endpoint, - jwks_uri=provider.jwks_uri, - ) # type: OpenIDProviderMetadata - self._provider_needs_discovery = provider.discover + + # cache of metadata for the identity provider (endpoint uris, mostly). This is + # loaded on-demand from the discovery endpoint (if discovery is enabled), with + # possible overrides from the config. Access via `load_metadata`. + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + # cache of JWKs used by the identity provider to sign tokens. Loaded on demand + # from the IdP's jwks_uri, if required. + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + self._user_mapping_provider = provider.user_mapping_provider_class( provider.user_mapping_provider_config ) @@ -274,11 +309,14 @@ class OidcProvider: # MXC URI for icon for this auth provider self.idp_icon = provider.idp_icon + # optional brand identifier for this auth provider + self.idp_brand = provider.idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) - def _validate_metadata(self): + def _validate_metadata(self, m: OpenIDProviderMetadata) -> None: """Verifies the provider metadata. This checks the validity of the currently loaded provider. Not @@ -297,7 +335,6 @@ class OidcProvider: if self._skip_verification is True: return - m = self._provider_metadata m.validate_issuer() m.validate_authorization_endpoint() m.validate_token_endpoint() @@ -332,11 +369,7 @@ class OidcProvider: ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token - if m.get("jwks") is None: - if m.get("jwks_uri") is not None: - m.validate_jwks_uri() - else: - raise ValueError('"jwks_uri" must be set') + m.validate_jwks_uri() @property def _uses_userinfo(self) -> bool: @@ -353,11 +386,15 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) - async def load_metadata(self) -> OpenIDProviderMetadata: - """Load and validate the provider metadata. + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: + """Return the provider metadata. - The values metadatas are discovered if ``oidc_config.discovery`` is - ``True`` and then cached. + If this is the first call, the metadata is built from the config and from the + metadata discovery endpoint (if enabled), and then validated. If the metadata + is successfully validated, it is then cached for future use. + + Args: + force: If true, any cached metadata is discarded to force a reload. Raises: ValueError: if something in the provider is not valid @@ -365,18 +402,41 @@ class OidcProvider: Returns: The provider's metadata. """ - # If we are using the OpenID Discovery documents, it needs to be loaded once - # FIXME: should there be a lock here? - if self._provider_needs_discovery: - url = get_well_known_url(self._provider_metadata["issuer"], external=True) + if force: + # reset the cached call to ensure we get a new result + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + return await self._provider_metadata.get() + + async def _load_metadata(self) -> OpenIDProviderMetadata: + # start out with just the issuer (unlike the other settings, discovered issuer + # takes precedence over configured issuer, because configured issuer is + # required for discovery to take place.) + # + metadata = OpenIDProviderMetadata(issuer=self._config.issuer) + + # load any data from the discovery endpoint, if enabled + if self._config.discover: + url = get_well_known_url(self._config.issuer, external=True) metadata_response = await self._http_client.get_json(url) - # TODO: maybe update the other way around to let user override some values? - self._provider_metadata.update(metadata_response) - self._provider_needs_discovery = False + metadata.update(metadata_response) - self._validate_metadata() + # override any discovered data with any settings in our config + if self._config.authorization_endpoint: + metadata["authorization_endpoint"] = self._config.authorization_endpoint - return self._provider_metadata + if self._config.token_endpoint: + metadata["token_endpoint"] = self._config.token_endpoint + + if self._config.userinfo_endpoint: + metadata["userinfo_endpoint"] = self._config.userinfo_endpoint + + if self._config.jwks_uri: + metadata["jwks_uri"] = self._config.jwks_uri + + self._validate_metadata(metadata) + + return metadata async def load_jwks(self, force: bool = False) -> JWKS: """Load the JSON Web Key Set used to sign ID tokens. @@ -406,27 +466,27 @@ class OidcProvider: ] } """ + if force: + # reset the cached call to ensure we get a new result + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + return await self._jwks.get() + + async def _load_jwks(self) -> JWKS: if self._uses_userinfo: # We're not using jwt signing, return an empty jwk set return {"keys": []} - # First check if the JWKS are loaded in the provider metadata. - # It can happen either if the provider gives its JWKS in the discovery - # document directly or if it was already loaded once. metadata = await self.load_metadata() - jwk_set = metadata.get("jwks") - if jwk_set is not None and not force: - return jwk_set - # Loading the JWKS using the `jwks_uri` metadata + # Load the JWKS using the `jwks_uri` metadata. uri = metadata.get("jwks_uri") if not uri: + # this should be unreachable: load_metadata validates that + # there is a jwks_uri in the metadata if _uses_userinfo is unset raise RuntimeError('Missing "jwks_uri" in metadata') jwk_set = await self._http_client.get_json(uri) - # Caching the JWKS in the provider's metadata - self._provider_metadata["jwks"] = jwk_set return jwk_set async def _exchange_code(self, code: str) -> Token: @@ -484,7 +544,10 @@ class OidcProvider: # We're not using the SimpleHttpClient util methods as we don't want to # check the HTTP status code and we do the body encoding ourself. response = await self._http_client.request( - method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, ) # This is used in multiple error messages below @@ -562,6 +625,7 @@ class OidcProvider: Returns: UserInfo: an object representing the user. """ + logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() resp = await self._http_client.get_json( @@ -569,6 +633,8 @@ class OidcProvider: headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, ) + logger.debug("Retrieved user info from userinfo endpoint: %r", resp) + return UserInfo(resp) async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: @@ -597,17 +663,19 @@ class OidcProvider: claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) claim_options = {"iss": {"values": [metadata["issuer"]]}} + id_token = token["id_token"] + logger.debug("Attempting to decode JWT id_token %r", id_token) + # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() try: claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, @@ -617,13 +685,15 @@ class OidcProvider: logger.info("Reloading JWKS after decode error") jwk_set = await self.load_jwks(force=True) # try reloading the jwks claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, claims_params=claims_params, ) + logger.debug("Decoded id_token JWT %r; validating", claims) + claims.validate(leeway=120) # allows 2 min of clock skew return UserInfo(claims) @@ -640,7 +710,7 @@ class OidcProvider: - ``client_id``: the client ID set in ``oidc_config.client_id`` - ``response_type``: ``code`` - - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback`` - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string @@ -678,14 +748,18 @@ class OidcProvider: ui_auth_session_id=ui_auth_session_id, ), ) - request.addCookie( - SESSION_COOKIE_NAME, - cookie, - path="/_synapse/oidc", - max_age="3600", - httpOnly=True, - sameSite="lax", - ) + + # Set the cookies. See the comments on _SESSION_COOKIES for why there are two. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=%s; Max-Age=3600; %s" + % (cookie_name, cookie.encode("utf-8"), options) + ) metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") @@ -720,7 +794,7 @@ class OidcProvider: async def handle_oidc_callback( self, request: SynapseRequest, session_data: "OidcSessionData", code: str ) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback By this time we have already validated the session on the synapse side, and now need to do the provider-specific operations. This includes: @@ -741,19 +815,18 @@ class OidcProvider: """ # Exchange the code with the provider try: - logger.debug("Exchanging code") + logger.debug("Exchanging OAuth2 code for a token") token = await self._exchange_code(code) except OidcError as e: - logger.exception("Could not exchange code") + logger.exception("Could not exchange OAuth2 code") self._sso_handler.render_error(request, e.error, e.error_description) return - logger.debug("Successfully obtained OAuth2 access token") + logger.debug("Successfully obtained OAuth2 token data: %r", token) # Now that we have a token, get the userinfo, either by decoding the # `id_token` or by fetching the `userinfo_endpoint`. if self._uses_userinfo: - logger.debug("Fetching userinfo") try: userinfo = await self._fetch_userinfo(token) except Exception as e: @@ -761,7 +834,6 @@ class OidcProvider: self._sso_handler.render_error(request, "fetch_error", str(e)) return else: - logger.debug("Extracting userinfo from id_token") try: userinfo = await self._parse_id_token(token, nonce=session_data.nonce) except Exception as e: @@ -954,7 +1026,9 @@ class OidcSessionTokenGenerator: A signed macaroon token with the session information. """ macaroon = pymacaroons.Macaroon( - location=self._server_name, identifier="key", key=self._macaroon_secret_key, + location=self._server_name, + identifier="key", + key=self._macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = session") @@ -1074,7 +1148,8 @@ class OidcSessionData: UserAttributeDict = TypedDict( - "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} + "UserAttributeDict", + {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, ) C = TypeVar("C") @@ -1153,11 +1228,12 @@ def jinja_finalize(thing): env = Environment(finalize=jinja_finalize) -@attr.s +@attr.s(slots=True, frozen=True) class JinjaOidcMappingConfig: subject_claim = attr.ib(type=str) localpart_template = attr.ib(type=Optional[Template]) display_name_template = attr.ib(type=Optional[Template]) + email_template = attr.ib(type=Optional[Template]) extra_attributes = attr.ib(type=Dict[str, Template]) @@ -1174,23 +1250,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - localpart_template = None # type: Optional[Template] - if "localpart_template" in config: + def parse_template_config(option_name: str) -> Optional[Template]: + if option_name not in config: + return None try: - localpart_template = env.from_string(config["localpart_template"]) + return env.from_string(config[option_name]) except Exception as e: - raise ConfigError( - "invalid jinja template", path=["localpart_template"] - ) from e + raise ConfigError("invalid jinja template", path=[option_name]) from e - display_name_template = None # type: Optional[Template] - if "display_name_template" in config: - try: - display_name_template = env.from_string(config["display_name_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template", path=["display_name_template"] - ) from e + localpart_template = parse_template_config("localpart_template") + display_name_template = parse_template_config("display_name_template") + email_template = parse_template_config("email_template") extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: @@ -1210,6 +1280,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + email_template=email_template, extra_attributes=extra_attributes, ) @@ -1231,16 +1302,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): # a usable mxid. localpart += str(failures) if failures else "" - display_name = None # type: Optional[str] - if self._config.display_name_template is not None: - display_name = self._config.display_name_template.render( - user=userinfo - ).strip() + def render_template_field(template: Optional[Template]) -> Optional[str]: + if template is None: + return None + return template.render(user=userinfo).strip() - if display_name == "": - display_name = None + display_name = render_template_field(self._config.display_name_template) + if display_name == "": + display_name = None - return UserAttributeDict(localpart=localpart, display_name=display_name) + emails = [] # type: List[str] + email = render_template_field(self._config.email_template) + if email: + emails.append(email) + + return UserAttributeDict( + localpart=localpart, display_name=display_name, emails=emails + ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 5372753707..059064a4eb 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -197,7 +197,8 @@ class PaginationHandler: stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) r = await self.store.get_room_event_before_stream_ordering( - room_id, stream_ordering, + room_id, + stream_ordering, ) if not r: logger.warning( @@ -223,7 +224,12 @@ class PaginationHandler: # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", self._purge_history, purge_id, room_id, token, True, + "_purge_history", + self._purge_history, + purge_id, + room_id, + token, + True, ) def start_purge_history( @@ -389,7 +395,9 @@ class PaginationHandler: ) await self.hs.get_federation_handler().maybe_backfill( - room_id, curr_topo, limit=pagin_config.limit, + room_id, + curr_topo, + limit=pagin_config.limit, ) to_room_key = None diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 22d1e9d35c..fb85b19770 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -349,10 +349,13 @@ class PresenceHandler(BasePresenceHandler): [self.user_to_current_state[user_id] for user_id in unpersisted] ) - async def _update_states(self, new_states): + async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None: """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. + + Args: + new_states: The new user presence state updates to process. """ now = self.clock.time_msec() @@ -368,7 +371,7 @@ class PresenceHandler(BasePresenceHandler): new_states_dict = {} for new_state in new_states: new_states_dict[new_state.user_id] = new_state - new_state = new_states_dict.values() + new_states = new_states_dict.values() for new_state in new_states: user_id = new_state.user_id @@ -635,8 +638,7 @@ class PresenceHandler(BasePresenceHandler): self.external_process_last_updated_ms.pop(process_id, None) async def current_state_for_user(self, user_id): - """Get the current presence state for a user. - """ + """Get the current presence state for a user.""" res = await self.current_state_for_users([user_id]) return res[user_id] @@ -658,17 +660,6 @@ class PresenceHandler(BasePresenceHandler): self._push_to_remotes(states) - async def notify_for_states(self, state, stream_id): - parties = await get_interested_parties(self.store, [state]) - room_ids_to_states, users_to_states = parties - - self.notifier.on_new_event( - "presence_key", - stream_id, - rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states], - ) - def _push_to_remotes(self, states): """Sends state updates to remote servers. @@ -678,8 +669,7 @@ class PresenceHandler(BasePresenceHandler): self.federation.send_presence(states) async def incoming_presence(self, origin, content): - """Called when we receive a `m.presence` EDU from a remote server. - """ + """Called when we receive a `m.presence` EDU from a remote server.""" if not self._presence_enabled: return @@ -729,8 +719,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states(updates) async def set_state(self, target_user, state, ignore_status_msg=False): - """Set the presence state of the user. - """ + """Set the presence state of the user.""" status_msg = state.get("status_msg", None) presence = state["presence"] @@ -758,8 +747,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states([prev_state.copy_and_replace(**new_fields)]) async def is_visible(self, observed_user, observer_user): - """Returns whether a user can see another user's presence. - """ + """Returns whether a user can see another user's presence.""" observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) @@ -953,8 +941,7 @@ class PresenceHandler(BasePresenceHandler): def should_notify(old_state, new_state): - """Decides if a presence state change should be sent to interested parties. - """ + """Decides if a presence state change should be sent to interested parties.""" if old_state == new_state: return False diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index c02b951031..2f62d84fb5 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -207,7 +207,8 @@ class ProfileHandler(BaseHandler): # This must be done by the target user himself. if by_admin: requester = create_requester( - target_user, authenticated_entity=requester.authenticated_entity, + target_user, + authenticated_entity=requester.authenticated_entity, ) await self.store.set_profile_displayname( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index cc21fc2284..6a6c528849 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler): ) else: hs.get_federation_registry().register_instances_for_edu( - "m.receipt", hs.config.worker.writers.receipts, + "m.receipt", + hs.config.worker.writers.receipts, ) self.clock = self.hs.get_clock() self.state = hs.get_state_handler() async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: - """Called when we receive an EDU of type m.receipt from a remote HS. - """ + """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] for room_id, room_values in content.items(): for receipt_type, users in room_values.items(): @@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler): await self._handle_new_receipts(receipts) async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: - """Takes a list of receipts, stores them and informs the notifier. - """ + """Takes a list of receipts, stores them and informs the notifier.""" min_batch_id = None # type: Optional[int] max_batch_id = None # type: Optional[int] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a2cf0f6f3e..3cda89657e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,8 +14,9 @@ # limitations under the License. """Contains functions for registering clients.""" + import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -61,8 +62,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = RegisterDeviceReplicationServlet.make_client( hs ) - self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( - hs + self._post_registration_client = ( + ReplicationPostRegisterActionsServlet.make_client(hs) ) else: self.device_handler = hs.get_device_handler() @@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler): user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: List[str] = [], + bind_emails: Iterable[str] = [], by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, ) -> str: @@ -188,12 +189,15 @@ class RegistrationHandler(BaseHandler): self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( - threepid, localpart, user_agent_ips or [], + threepid, + localpart, + user_agent_ips or [], ) if result == RegistrationBehaviour.DENY: logger.info( - "Blocked registration of %r", localpart, + "Blocked registration of %r", + localpart, ) # We return a 429 to make it not obvious that they've been # denied. @@ -202,7 +206,8 @@ class RegistrationHandler(BaseHandler): shadow_banned = result == RegistrationBehaviour.SHADOW_BAN if shadow_banned: logger.info( - "Shadow banning registration of %r", localpart, + "Shadow banning registration of %r", + localpart, ) # do not check_auth_blocking if the call is coming through the Admin API @@ -368,7 +373,9 @@ class RegistrationHandler(BaseHandler): config["room_alias_name"] = room_alias.localpart info, _ = await room_creation_handler.create_room( - fake_requester, config=config, ratelimit=False, + fake_requester, + config=config, + ratelimit=False, ) # If the room does not require an invite, but another user @@ -693,6 +700,8 @@ class RegistrationHandler(BaseHandler): access_token: The access token of the newly logged in device, or None if `inhibit_login` enabled. """ + # TODO: 3pid registration can actually happen on the workers. Consider + # refactoring it. if self.hs.config.worker_app: await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token @@ -750,7 +759,10 @@ class RegistrationHandler(BaseHandler): return await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) # And we add an email pusher for them by default, but only @@ -802,5 +814,8 @@ class RegistrationHandler(BaseHandler): raise await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee27d99135..a488df10d6 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,6 +38,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -126,6 +127,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -193,7 +198,9 @@ class RoomCreationHandler(BaseHandler): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = await self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], room_version=new_version, + creator_id=user_id, + is_public=r["is_public"], + room_version=new_version, ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -231,7 +238,9 @@ class RoomCreationHandler(BaseHandler): # now send the tombstone await self.event_creation_handler.handle_new_client_event( - requester=requester, event=tombstone_event, context=tombstone_context, + requester=requester, + event=tombstone_event, + context=tombstone_context, ) old_room_state = await tombstone_context.get_current_state_ids() @@ -252,7 +261,10 @@ class RoomCreationHandler(BaseHandler): # finally, shut down the PLs in the old room, and update them in the new # room. await self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, + old_room_id, + new_room_id, + old_room_state, ) return new_room_id @@ -420,17 +432,20 @@ class RoomCreationHandler(BaseHandler): # Copy over user power levels now as this will not be possible with >100PL users once # the room has been created - # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) - state_default = power_levels.get("state_default", 0) - ban = power_levels.get("ban") + state_default = power_levels.get("state_default", 50) + ban = power_levels.get("ban", 50) needed_power_level = max(state_default, ban, max(event_power_levels.values())) + # Get the user's current power level, this matches the logic in get_user_power_level, + # but without the entire state map. + user_power_levels = power_levels.setdefault("users", {}) + users_default = power_levels.get("users_default", 0) + current_power_level = user_power_levels.get(user_id, users_default) # Raise the requester's power level in the new room if necessary - current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - power_levels["users"][user_id] = needed_power_level + user_power_levels[user_id] = needed_power_level await self._send_events_for_new_room( requester, @@ -562,7 +577,7 @@ class RoomCreationHandler(BaseHandler): ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, ) -> Tuple[dict, int]: - """ Creates a new room. + """Creates a new room. Args: requester: @@ -662,6 +677,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") @@ -680,7 +698,9 @@ class RoomCreationHandler(BaseHandler): is_public = visibility == "public" room_id = await self._generate_room_id( - creator_id=user_id, is_public=is_public, room_version=room_version, + creator_id=user_id, + is_public=is_public, + room_version=room_version, ) # Check whether this visibility value is blocked by a third party module @@ -821,7 +841,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - # Always wait for room creation to progate before returning + # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), "events", @@ -873,7 +893,10 @@ class RoomCreationHandler(BaseHandler): _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False, ignore_shadow_ban=True, + creator, + event, + ratelimit=False, + ignore_shadow_ban=True, ) return last_stream_id @@ -973,7 +996,10 @@ class RoomCreationHandler(BaseHandler): return last_sent_stream_id async def _generate_room_id( - self, creator_id: str, is_public: bool, room_version: RoomVersion, + self, + creator_id: str, + is_public: bool, + room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -997,41 +1023,51 @@ class RoomCreationHandler(BaseHandler): class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.auth = hs.get_auth() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state async def get_event_context( self, - user: UserID, + requester: Requester, room_id: str, event_id: str, limit: int, event_filter: Optional[Filter], + use_admin_priviledge: bool = False, ) -> Optional[JsonDict]: """Retrieves events, pagination tokens and state around a given event in a room. Args: - user + requester room_id event_id limit: The maximum number of events to return in total (excluding state). event_filter: the filter to apply to the events returned (excluding the target event_id) - + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. Returns: dict, or None if the event isn't found """ + user = requester.user + if use_admin_priviledge: + await assert_user_is_admin(self.auth, requester.user) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users - def filter_evts(events): - return filter_events_for_client( + async def filter_evts(events): + if use_admin_priviledge: + return events + return await filter_events_for_client( self.storage, user.to_string(), events, is_peeking=is_peeking ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e001e418f9..1660921306 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -144,6 +155,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + """Ratelimit invites by room and by target user. + + If room ID is missing then we just rate limit by target user. + """ + if room_id: + self._invites_per_room_limiter.ratelimit(room_id) + + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -170,7 +191,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # do it up front for efficiency.) if txn_id and requester.access_token_id: existing_event_id = await self.store.get_event_id_from_transaction_id( - room_id, requester.user.to_string(), requester.access_token_id, txn_id, + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, ) if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) @@ -217,7 +241,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit, + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, ) if event.membership == Membership.LEAVE: @@ -387,8 +415,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + # Don't ratelimit application services. + if not requester.app_service or requester.app_service.is_rate_limited(): + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -412,7 +446,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): block_invite = True if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), target_id, room_id ): logger.info("Blocking invite due to spam checker") block_invite = True @@ -556,7 +590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # send the rejection to the inviter's HS (with fallback to # local event) return await self.remote_reject_invite( - invite.event_id, txn_id, requester, content, + invite.event_id, + txn_id, + requester, + content, ) # the inviter was on our server, but has now left. Carry on @@ -1029,8 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" # filter ourselves out of remote_room_hosts: do_invite_join ignores it # and if it is the only entry we'd like to return a 404 rather than a # 500. @@ -1184,7 +1220,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): event.internal_metadata.out_of_band_membership = True result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[UserID.from_string(target_user)], + requester, + event, + context, + extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering assert result_event.internal_metadata.stream_ordering @@ -1192,8 +1231,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return result_event.event_id, result_event.internal_metadata.stream_ordering async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room - """ + """Implements RoomMemberHandler._user_left_room""" user_left_room(self.distributor, target, room_id) async def forget(self, user: UserID, room_id: str) -> None: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index f2e88f6a5b..108730a7a1 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -44,8 +44,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") @@ -80,8 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret["event_id"], ret["stream_id"] async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room - """ + """Implements RoomMemberHandler._user_left_room""" await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 38461cf79d..a9645b77d8 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -23,7 +23,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.config.saml2_config import SamlAttributeRequirement from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -78,9 +77,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" - # we do not currently support icons for SAML auth, but this is required by + # we do not currently support icons/brands for SAML auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] @@ -121,7 +121,8 @@ class SamlHandler(BaseHandler): now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( - creation_time=now, ui_auth_session_id=ui_auth_session_id, + creation_time=now, + ui_auth_session_id=ui_auth_session_id, ) for key, value in info["headers"]: @@ -132,7 +133,7 @@ class SamlHandler(BaseHandler): raise Exception("prepare_for_authenticate didn't return a Location header") async def handle_saml_response(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_matrix/saml2/authn_response + """Handle an incoming request to /_synapse/client/saml2/authn_response Args: request: the incoming request from the browser. We'll @@ -238,12 +239,10 @@ class SamlHandler(BaseHandler): # Ensure that the attributes of the logged in user meet the required # attributes. - for requirement in self._saml2_attribute_requirements: - if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._sso_handler.render_error( - request, "unauthorised", "You are not authorised to log in here." - ) - return + if not self._sso_handler.check_required_attributes( + request, saml2_auth.ava, self._saml2_attribute_requirements + ): + return # Call the mapper to register/login the user try: @@ -372,21 +371,6 @@ class SamlHandler(BaseHandler): del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: - values = ava.get(req.attribute, []) - for v in values: - if v == req.value: - return True - - logger.info( - "SAML2 attribute %s did not match required value '%s' (was '%s')", - req.attribute, - req.value, - values, - ) - return False - - DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) @@ -467,7 +451,8 @@ class DefaultSamlMappingProvider: mxid_source = saml_response.ava[self._mxid_source_attribute][0] except KeyError: logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + "SAML2 response lacks a '%s' attestation", + self._mxid_source_attribute, ) raise SynapseError( 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 66f1bbcfc4..94062e79cb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -15,23 +15,28 @@ import itertools import logging -from typing import Iterable +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.storage.state import StateFilter +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() @@ -87,13 +92,15 @@ class SearchHandler(BaseHandler): return historical_room_ids - async def search(self, user, content, batch=None): + async def search( + self, user: UserID, content: JsonDict, batch: Optional[str] = None + ) -> JsonDict: """Performs a full text search for a user. Args: - user (UserID) - content (dict): Search parameters - batch (str): The next_batch parameter. Used for pagination. + user + content: Search parameters + batch: The next_batch parameter. Used for pagination. Returns: dict to be returned to the client with results of search @@ -186,7 +193,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] + historical_room_ids = [] # type: List[str] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -209,8 +216,10 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] - room_groups = {} # Holds result of grouping by room, if applicable - sender_group = {} # Holds result of grouping by sender, if applicable + # Holds result of grouping by room, if applicable + room_groups = {} # type: Dict[str, JsonDict] + # Holds result of grouping by sender, if applicable + sender_group = {} # type: Dict[str, JsonDict] # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -254,7 +263,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] + room_events = [] # type: List[EventBase] i = 0 pagination_token = batch_token @@ -418,13 +427,10 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = {e.room_id for e in allowed_events} - for room_id in rooms: + for room_id in {e.room_id for e in allowed_events}: state = await self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) - state_results.values() - # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong @@ -448,9 +454,9 @@ class SearchHandler(BaseHandler): if state_results: s = {} - for room_id, state in state_results.items(): + for room_id, state_events in state_results.items(): s[room_id] = await self._event_serializer.serialize_events( - state, time_now + state_events, time_now ) rooms_cat_res["state"] = s diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a5d67f828f..84af2dde7e 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -13,24 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - self._password_policy_handler = hs.get_password_policy_handler() async def set_password( self, @@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler): password_hash: str, logout_devices: bool, requester: Optional[Requester] = None, - ): + ) -> None: if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d493327a10..514b1f69d8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,21 +14,34 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, +) from urllib.parse import urlencode import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string @@ -80,6 +93,11 @@ class SsoIdentityProvider(Protocol): """Optional MXC URI for user-facing icon""" return None + @property + def idp_brand(self) -> Optional[str]: + """Optional branding identifier""" + return None + @abc.abstractmethod async def handle_redirect_request( self, @@ -109,7 +127,7 @@ class UserAttributes: # enter one. localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) - emails = attr.ib(type=List[str], default=attr.Factory(list)) + emails = attr.ib(type=Collection[str], default=attr.Factory(list)) @attr.s(slots=True) @@ -124,7 +142,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name = attr.ib(type=Optional[str]) - emails = attr.ib(type=List[str]) + emails = attr.ib(type=Collection[str]) # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -136,6 +154,12 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + use_display_name = attr.ib(type=bool, default=True) + emails_to_use = attr.ib(type=Collection[str], default=()) + terms_accepted_version = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -170,6 +194,8 @@ class SsoHandler: # map from idp_id to SsoIdentityProvider self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._consent_at_registration = hs.config.consent.user_consent_at_registration + def register_identity_provider(self, p: SsoIdentityProvider): p_id = p.idp_id assert p_id not in self._identity_providers @@ -235,7 +261,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +272,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +282,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" + urlencode( @@ -288,7 +327,8 @@ class SsoHandler: # Check if we already have a mapping for this user. previously_registered_user_id = await self._store.get_user_by_external_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # A match was found, return the user ID. @@ -369,13 +409,16 @@ class SsoHandler: to an additional page. (e.g. to prompt for more information) """ + new_user = False + # grab a lock while we try to find a mapping for this user. This seems... # optimistic, especially for implementations that end up redirecting to # interstitial pages. with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # Check for grandfathering of users. @@ -409,13 +452,19 @@ class SsoHandler: get_request_user_agent(request), request.getClientIP(), ) + new_user = True await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_login_attributes + user_id, + request, + client_redirect_url, + extra_login_attributes, + new_user=new_user, ) async def _call_attribute_mapper( - self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + self, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], ) -> UserAttributes: """Call the attribute mapper function in a loop, until we get a unique userid""" for i in range(self._MAP_USERNAME_RETRIES): @@ -501,7 +550,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) @@ -586,7 +635,8 @@ class SsoHandler: """ user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) user_id_to_verify = await self._auth_handler.get_session_data( @@ -625,12 +675,34 @@ class SsoHandler: # render an error page. html = self._bad_user_template.render( - server_name=self._server_name, user_id_to_verify=user_id_to_verify, + server_name=self._server_name, + user_id_to_verify=user_id_to_verify, ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def check_username_availability( - self, localpart: str, session_id: str, + self, + localpart: str, + session_id: str, ) -> bool: """Handle an "is username available" callback check @@ -645,12 +717,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -667,7 +734,12 @@ class SsoHandler: return not user_infos async def handle_submit_username_request( - self, request: SynapseRequest, localpart: str, session_id: str + self, + request: SynapseRequest, + session_id: str, + localpart: str, + use_display_name: bool, + emails_to_use: Iterable[str], ) -> None: """Handle a request to the username-picker 'submit' endpoint @@ -677,21 +749,104 @@ class SsoHandler: request: HTTP request localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie + use_display_name: whether the user wants to use the suggested display name + emails_to_use: emails that the user would like to use """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + + # update the session with the user's choices + session.chosen_localpart = localpart + session.use_display_name = use_display_name + + emails_from_idp = set(session.emails) + filtered_emails = set() # type: Set[str] + + # we iterate through the list rather than just building a set conjunction, so + # that we can log attempts to use unknown addresses + for email in emails_to_use: + if email in emails_from_idp: + filtered_emails.add(email) + else: + logger.warning( + "[session %s] ignoring user request to use unknown email address %r", + session_id, + email, + ) + session.emails_to_use = filtered_emails + + # we may now need to collect consent from the user, in which case, redirect + # to the consent-extraction-unit + if self._consent_at_registration: + redirect_url = b"/_synapse/client/new_user_consent" + + # otherwise, redirect to the completion page + else: + redirect_url = b"/_synapse/client/sso_register" + + respond_with_redirect(request, redirect_url) + + async def handle_terms_accepted( + self, request: Request, session_id: str, terms_version: str + ): + """Handle a request to the new-user 'consent' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + terms_version: the version of the terms which the user viewed and consented + to + """ + logger.info( + "[session %s] User consented to terms version %s", + session_id, + terms_version, + ) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + + session.terms_accepted_version = terms_version + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. - logger.info("[session %s] Registering localpart %s", session_id, localpart) + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, - display_name=session.display_name, - emails=session.emails, + localpart=session.chosen_localpart, + emails=session.emails_to_use, ) + if session.use_display_name: + attributes.display_name = session.display_name + # the following will raise a 400 error if the username has been taken in the # meantime. user_id = await self._register_mapped_user( @@ -702,7 +857,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -715,11 +875,21 @@ class SsoHandler: path=b"/", ) + auth_result = {} + if session.terms_accepted_version: + # TODO: make this less awful. + auth_result[LoginType.TERMS] = True + + await self._registration_handler.post_registration_actions( + user_id, auth_result, access_token=None + ) + await self._auth_handler.complete_sso_login( user_id, request, session.client_redirect_url, session.extra_login_attributes, + new_user=True, ) def _expire_old_sessions(self): @@ -733,3 +903,82 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return False + + return True + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") + + +def _check_attribute_requirement( + attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement +) -> bool: + """Check if SSO attributes meet the proper requirements. + + Args: + attributes: A mapping of attributes to an iterable of one or more values. + requirement: The configured requirement to check. + + Returns: + True if the required attribute was found and had a proper value. + """ + if req.attribute not in attributes: + logger.info("SSO attribute missing: %s", req.attribute) + return False + + # If the requirement is None, the attribute existing is enough. + if req.value is None: + return True + + values = attributes[req.attribute] + if req.value in values: + return True + + logger.info( + "SSO attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + return False diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index fb4f70e8e2..b3f9875358 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -14,15 +14,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class StateDeltasHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change( + self, + prev_event_id: Optional[str], + event_id: Optional[str], + key_name: str, + public_value: str, + ) -> Optional[bool]: """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index dc62b21c06..924281144c 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -12,13 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple + +from typing_extensions import Counter as CounterType from synapse.api.constants import EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +37,7 @@ class StatsHandler: Heavily derived from UserDirectoryHandler """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -44,7 +50,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None + self.pos = None # type: Optional[int] # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -56,9 +62,8 @@ class StatsHandler: # we start populating stats self.clock.call_later(0, self.notify_new_event) - def notify_new_event(self): - """Called when there may be more deltas to process - """ + def notify_new_event(self) -> None: + """Called when there may be more deltas to process""" if not self.stats_enabled or self._is_processing: return @@ -72,7 +77,7 @@ class StatsHandler: run_as_background_process("stats.notify_new_event", process) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_stats_positions() @@ -110,10 +115,10 @@ class StatsHandler: ) for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, {}).update(fields) + room_deltas.setdefault(room_id, Counter()).update(fields) for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, {}).update(fields) + user_deltas.setdefault(user_id, Counter()).update(fields) logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -131,19 +136,20 @@ class StatsHandler: self.pos = max_pos - async def _handle_deltas(self, deltas): + async def _handle_deltas( + self, deltas: Iterable[JsonDict] + ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]: """Called with the state deltas to process Returns: - tuple[dict[str, Counter], dict[str, counter]] Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ - room_to_stats_deltas = {} - user_to_stats_deltas = {} + room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - room_to_state_updates = {} + room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] for delta in deltas: typ = delta["type"] @@ -173,7 +179,7 @@ class StatsHandler: ) continue - event_content = {} + event_content = {} # type: JsonDict sender = None if event_id is not None: @@ -257,13 +263,13 @@ class StatsHandler: ) if has_changed_joinedness: - delta = +1 if membership == Membership.JOIN else -1 + membership_delta = +1 if membership == Membership.JOIN else -1 user_to_stats_deltas.setdefault(user_id, Counter())[ "joined_rooms" - ] += delta + ] += membership_delta - room_stats_delta["local_users_in_room"] += delta + room_stats_delta["local_users_in_room"] += membership_delta elif typ == EventTypes.Create: room_state["is_federatable"] = ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5c7590f38e..4e8ed7b33f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -339,8 +339,7 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now. - """ + """Get the sync for client needed to match what the server has now.""" return await self.generate_sync_result(sync_config, since_token, full_state) async def push_rules_for_user(self, user: UserID) -> JsonDict: @@ -564,7 +563,7 @@ class SyncHandler: stream_position: StreamToken, state_filter: StateFilter = StateFilter.all(), ) -> StateMap[str]: - """ Get the room state at a particular stream position + """Get the room state at a particular stream position Args: room_id: room for which to get state @@ -598,7 +597,7 @@ class SyncHandler: state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: - """ Works out a room summary block for this room, summarising the number + """Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds state events to 'state' if needed to describe the heroes. @@ -743,7 +742,7 @@ class SyncHandler: now_token: StreamToken, full_state: bool, ) -> MutableStateMap[EventBase]: - """ Works out the difference in state between the start of the timeline + """Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -820,8 +819,10 @@ class SyncHandler: ) elif batch.limited: if batch: - state_at_timeline_start = await self.state_store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_at_timeline_start = ( + await self.state_store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: # We can get here if the user has ignored the senders of all @@ -955,8 +956,7 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result. - """ + """Generates a sync result.""" # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1030,8 +1030,8 @@ class SyncHandler: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( - user_id, device_id + unused_fallback_key_types = ( + await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) logger.debug("Fetching group data") @@ -1176,8 +1176,10 @@ class SyncHandler: # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) - user_signatures_changed = await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key + user_signatures_changed = ( + await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) ) users_that_have_changed.update(user_signatures_changed) @@ -1393,8 +1395,10 @@ class SyncHandler: logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) ) # If there is ignored users account data and it matches the proper type, @@ -1499,8 +1503,7 @@ class SyncHandler: async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Gets the the changes that have happened since the last sync. - """ + """Gets the the changes that have happened since the last sync.""" user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e919a8f9ed..096d199f4c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,13 +15,13 @@ import logging import random from collections import namedtuple -from typing import TYPE_CHECKING, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -61,23 +61,23 @@ class FollowerTypingHandler: if hs.config.worker.writers.typing != hs.get_instance_name(): hs.get_federation_registry().register_instance_for_edu( - "m.typing", hs.config.worker.writers.typing, + "m.typing", + hs.config.worker.writers.typing, ) # map room IDs to serial numbers - self._room_serials = {} + self._room_serials = {} # type: Dict[str, int] # map room IDs to sets of users currently typing - self._room_typing = {} + self._room_typing = {} # type: Dict[str, Set[str]] - self._member_last_federation_poke = {} + self._member_last_federation_poke = {} # type: Dict[RoomMember, int] self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) - def _reset(self): - """Reset the typing handler's data caches. - """ + def _reset(self) -> None: + """Reset the typing handler's data caches.""" # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing @@ -86,7 +86,7 @@ class FollowerTypingHandler: self._member_last_federation_poke = {} self.wheel_timer = WheelTimer(bucket_size=5000) - def _handle_timeouts(self): + def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() @@ -96,7 +96,7 @@ class FollowerTypingHandler: for member in members: self._handle_timeout_for_member(now, member) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: if not self.is_typing(member): # Nothing to do if they're no longer typing return @@ -114,10 +114,10 @@ class FollowerTypingHandler: # each person typing. self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) - def is_typing(self, member): + def is_typing(self, member: RoomMember) -> bool: return member.user_id in self._room_typing.get(member.room_id, []) - async def _push_remote(self, member, typing): + async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: return @@ -148,9 +148,8 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): - """Should be called whenever we receive updates for typing stream. - """ + ) -> None: + """Should be called whenever we receive updates for typing stream.""" if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just @@ -178,7 +177,7 @@ class FollowerTypingHandler: async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] - ): + ) -> None: """Process a change in typing of a room from replication, sending EDUs for any local users. """ @@ -194,12 +193,12 @@ class FollowerTypingHandler: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), False) - def get_current_token(self): + def get_current_token(self) -> int: return self._latest_room_serial class TypingWriterHandler(FollowerTypingHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) assert hs.config.worker.writers.typing == hs.get_instance_name() @@ -213,14 +212,15 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - self._member_typing_until = {} # clock time we expect to stop + # clock time we expect to stop + self._member_typing_until = {} # type: Dict[RoomMember, int] # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", self._latest_room_serial ) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: super()._handle_timeout_for_member(now, member) if not self.is_typing(member): @@ -233,7 +233,9 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) return - async def started_typing(self, target_user, requester, room_id, timeout): + async def started_typing( + self, target_user: UserID, requester: Requester, room_id: str, timeout: int + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -263,11 +265,13 @@ class TypingWriterHandler(FollowerTypingHandler): if was_present: # No point sending another notification - return None + return self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, requester, room_id): + async def stopped_typing( + self, target_user: UserID, requester: Requester, room_id: str + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -290,23 +294,23 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) - def user_left_room(self, user, room_id): + def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) self._stopped_typing(member) - def _stopped_typing(self, member): + def _stopped_typing(self, member: RoomMember) -> None: if member.user_id not in self._room_typing.get(member.room_id, set()): # No point - return None + return self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) self._push_update(member=member, typing=False) - def _push_update(self, member, typing): + def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_as_background_process( @@ -315,7 +319,7 @@ class TypingWriterHandler(FollowerTypingHandler): self._push_update_local(member=member, typing=typing) - async def _recv_edu(self, origin, content): + async def _recv_edu(self, origin: str, content: JsonDict) -> None: room_id = content["room_id"] user_id = content["user_id"] @@ -340,7 +344,7 @@ class TypingWriterHandler(FollowerTypingHandler): self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) self._push_update_local(member=member, typing=content["typing"]) - def _push_update_local(self, member, typing): + def _push_update_local(self, member: RoomMember, typing: bool) -> None: room_set = self._room_typing.setdefault(member.room_id, set()) if typing: room_set.add(member.user_id) @@ -386,7 +390,7 @@ class TypingWriterHandler(FollowerTypingHandler): changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id - ) + ) # type: Optional[Iterable[str]] if changed_rooms is None: changed_rooms = self._room_serials @@ -412,13 +416,13 @@ class TypingWriterHandler(FollowerTypingHandler): def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: # The writing process should never get updates from replication. raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: @@ -427,7 +431,7 @@ class TypingNotificationEventSource: # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id): + def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", @@ -462,7 +466,9 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: Iterable[str], **kwargs + ) -> Tuple[List[JsonDict], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -478,5 +484,5 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - def get_current_key(self): + def get_current_key(self) -> int: return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d4651c8348..1a8340000a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler): return results def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.update_user_directory: return @@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler): ) async def handle_user_deactivated(self, user_id: str) -> None: - """Called when a user ID is deactivated - """ + """Called when a user ID is deactivated""" # FIXME(#3714): We should probably do this in the same worker as all # the other changes. await self.store.remove_from_user_dir(user_id) @@ -145,7 +143,7 @@ class UserDirectoryHandler(StateDeltasHandler): if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() - # If still None then the initial background update hasn't happened yet + # If still None then the initial background update hasn't happened yet. if self.pos is None: return None @@ -176,8 +174,7 @@ class UserDirectoryHandler(StateDeltasHandler): await self.store.update_user_directory_stream_pos(max_pos) async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: - """Called with the state deltas to process - """ + """Called with the state deltas to process""" for delta in deltas: typ = delta["type"] state_key = delta["state_key"] @@ -233,6 +230,11 @@ class UserDirectoryHandler(StateDeltasHandler): if change: # The user joined event = await self.store.get_event(event_id, allow_none=True) + # It isn't expected for this event to not exist, but we + # don't want the entire background process to break. + if event is None: + continue + profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), |