summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/admin.py6
-rw-r--r--synapse/handlers/appservice.py4
-rw-r--r--synapse/handlers/auth.py34
-rw-r--r--synapse/handlers/cas_handler.py56
-rw-r--r--synapse/handlers/deactivate_account.py6
-rw-r--r--synapse/handlers/device.py30
-rw-r--r--synapse/handlers/devicemessage.py7
-rw-r--r--synapse/handlers/e2e_keys.py24
-rw-r--r--synapse/handlers/events.py3
-rw-r--r--synapse/handlers/federation.py109
-rw-r--r--synapse/handlers/groups_local.py24
-rw-r--r--synapse/handlers/identity.py5
-rw-r--r--synapse/handlers/initial_sync.py12
-rw-r--r--synapse/handlers/message.py34
-rw-r--r--synapse/handlers/oidc_handler.py193
-rw-r--r--synapse/handlers/pagination.py14
-rw-r--r--synapse/handlers/presence.py15
-rw-r--r--synapse/handlers/profile.py3
-rw-r--r--synapse/handlers/receipts.py9
-rw-r--r--synapse/handlers/register.py28
-rw-r--r--synapse/handlers/room.py44
-rw-r--r--synapse/handlers/room_member.py27
-rw-r--r--synapse/handlers/room_member_worker.py6
-rw-r--r--synapse/handlers/saml_handler.py32
-rw-r--r--synapse/handlers/sso.py93
-rw-r--r--synapse/handlers/stats.py3
-rw-r--r--synapse/handlers/sync.py37
-rw-r--r--synapse/handlers/typing.py9
-rw-r--r--synapse/handlers/user_directory.py9
29 files changed, 540 insertions, 336 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py

index 37e63da9b1..db68c94c50 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -203,13 +203,11 @@ class AdminHandler(BaseHandler): class ExfiltrationWriter(metaclass=abc.ABCMeta): - """Interface used to specify how to write exported data. - """ + """Interface used to specify how to write exported data.""" @abc.abstractmethod def write_events(self, room_id: str, events: List[EventBase]) -> None: - """Write a batch of events for a room. - """ + """Write a batch of events for a room.""" raise NotImplementedError() @abc.abstractmethod diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5c6458eb52..deab8ff2d0 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -290,7 +290,9 @@ class ApplicationServicesHandler: if not interested: continue presence_events, _ = await presence_source.get_new_events( - user=user, service=service, from_key=from_key, + user=user, + service=service, + from_key=from_key, ) time_now = self.clock.time_msec() events.extend( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 648fe91f53..9ba9f591d9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier( # Ensure the identifier has a type if "type" not in identifier: raise SynapseError( - 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + 400, + "'identifier' dict has no key 'type'", + errcode=Codes.MISSING_PARAM, ) return identifier @@ -351,7 +353,11 @@ class AuthHandler(BaseHandler): try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, description, get_new_session_data, + flows, + request, + request_body, + description, + get_new_session_data, ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). @@ -379,8 +385,7 @@ class AuthHandler(BaseHandler): return params, session_id async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: - """Get a list of the authentication types this user can use - """ + """Get a list of the authentication types this user can use""" ui_auth_types = set() @@ -723,7 +728,9 @@ class AuthHandler(BaseHandler): } def _auth_dict_for_flows( - self, flows: List[List[str]], session_id: str, + self, + flows: List[List[str]], + session_id: str, ) -> Dict[str, Any]: public_flows = [] for f in flows: @@ -880,7 +887,9 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, login_submission: Dict[str, Any], ratelimit: bool = False, + self, + login_submission: Dict[str, Any], + ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Authenticates the user for the /login API @@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler): raise async def _validate_userid_login( - self, username: str, login_submission: Dict[str, Any], + self, + username: str, + login_submission: Dict[str, Any], ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Helper for validate_login @@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler): # is considered OK since the newest SSO attributes should be most valid. if extra_attributes: self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( - self._clock.time_msec(), extra_attributes, + self._clock.time_msec(), + extra_attributes, ) # Create a login token @@ -1702,5 +1714,9 @@ class PasswordProvider: # This might return an awaitable, if it does block the log out # until it completes. await maybe_awaitable( - g(user_id=user_id, device_id=device_id, access_token=access_token,) + g( + user_id=user_id, + device_id=device_id, + access_token=access_token, + ) ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index bd35d1fb87..04972f9cf0 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from xml.etree import ElementTree as ET import attr @@ -33,8 +33,7 @@ logger = logging.getLogger(__name__) class CasError(Exception): - """Used to catch errors when validating the CAS ticket. - """ + """Used to catch errors when validating the CAS ticket.""" def __init__(self, error, error_description=None): self.error = error @@ -49,7 +48,7 @@ class CasError(Exception): @attr.s(slots=True, frozen=True) class CasResponse: username = attr.ib(type=str) - attributes = attr.ib(type=Dict[str, Optional[str]]) + attributes = attr.ib(type=Dict[str, List[Optional[str]]]) class CasHandler: @@ -100,7 +99,10 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) + return "%s?%s" % ( + self._cas_service_url, + urllib.parse.urlencode(args), + ) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] @@ -169,7 +171,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} + attributes = {} # type: Dict[str, List[Optional[str]]] for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -182,7 +184,7 @@ class CasHandler: tag = attribute.tag if "}" in tag: tag = tag.split("}")[1] - attributes[tag] = attribute.text + attributes.setdefault(tag, []).append(attribute.text) # Ensure a user was found. if user is None: @@ -296,36 +298,20 @@ class CasHandler: # first check if we're doing a UIA if session: return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, cas_response.username, session, request, + self.idp_id, + cas_response.username, + session, + request, ) # otherwise, we're handling a login request. # Ensure that the attributes of the logged in user meet the required # attributes. - for required_attribute, required_value in self._cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in cas_response.attributes: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return - - # Also need to check value - if required_value is not None: - actual_value = cas_response.attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return + if not self._sso_handler.check_required_attributes( + request, cas_response.attributes, self._cas_required_attributes + ): + return # Call the mapper to register/login the user @@ -372,9 +358,10 @@ class CasHandler: if failures: raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + # Arbitrarily use the first attribute found. display_name = cas_response.attributes.get( - self._cas_displayname_attribute, None - ) + self._cas_displayname_attribute, [None] + )[0] return UserAttributes(localpart=localpart, display_name=display_name) @@ -384,7 +371,8 @@ class CasHandler: user_id = UserID(localpart, self._hostname).to_string() logger.debug( - "Looking for existing account based on mapped %s", user_id, + "Looking for existing account based on mapped %s", + user_id, ) users = await self._store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index ac25e3e94f..7911d126f5 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -199,8 +199,7 @@ class DeactivateAccountHandler(BaseHandler): run_as_background_process("user_parter_loop", self._user_parter_loop) async def _user_parter_loop(self) -> None: - """Loop that parts deactivated users from rooms - """ + """Loop that parts deactivated users from rooms""" self._user_parter_running = True logger.info("Starting user parter") try: @@ -217,8 +216,7 @@ class DeactivateAccountHandler(BaseHandler): self._user_parter_running = False async def _part_user(self, user_id: str) -> None: - """Causes the given user_id to leave all the rooms they're joined to - """ + """Causes the given user_id to leave all the rooms they're joined to""" user = UserID.from_string(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 0863154f7a..df3cdc8fba 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -86,7 +86,7 @@ class DeviceWorkerHandler(BaseHandler): @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: - """ Retrieve the given device + """Retrieve the given device Args: user_id: The user to get the device from @@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace async def delete_device(self, user_id: str, device_id: str) -> None: - """ Delete the given device + """Delete the given device Args: user_id: The user to delete the device from. @@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, device_ids) async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: - """ Delete several devices + """Delete several devices Args: user_id: The user to delete devices from. @@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler): await self.notify_device_update(user_id, device_ids) async def update_device(self, user_id: str, device_id: str, content: dict) -> None: - """ Update the given device + """Update the given device Args: user_id: The user to update devices of. @@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler): device id of the dehydrated device """ device_id = await self.check_device_registered( - user_id, None, initial_device_display_name, + user_id, + None, + initial_device_display_name, ) old_device_id = await self.store.store_dehydrated_device( user_id, device_id, device_data @@ -803,7 +805,8 @@ class DeviceListUpdater: try: # Try to resync the current user's devices list. result = await self.user_device_resync( - user_id=user_id, mark_failed_as_stale=False, + user_id=user_id, + mark_failed_as_stale=False, ) # user_device_resync only returns a result if it managed to @@ -813,14 +816,17 @@ class DeviceListUpdater: # self.store.update_remote_device_list_cache). if result: logger.debug( - "Successfully resynced the device list for %s", user_id, + "Successfully resynced the device list for %s", + user_id, ) except Exception as e: # If there was an issue resyncing this user, e.g. if the remote # server sent a malformed result, just log the error instead of # aborting all the subsequent resyncs. logger.debug( - "Could not resync the device list for %s: %s", user_id, e, + "Could not resync the device list for %s: %s", + user_id, + e, ) finally: # Allow future calls to retry resyncinc out of sync device lists. @@ -855,7 +861,9 @@ class DeviceListUpdater: return None except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to handle device list update for %s: %s", user_id, e, + "Failed to handle device list update for %s: %s", + user_id, + e, ) if mark_failed_as_stale: @@ -931,7 +939,9 @@ class DeviceListUpdater: # Handle cross-signing keys. cross_signing_device_ids = await self.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + cross_signing_device_ids diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0c7737e09d..1aa7d803b5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -62,7 +62,8 @@ class DeviceMessageHandler: ) else: hs.get_federation_registry().register_instances_for_edu( - "m.direct_to_device", hs.config.worker.writers.to_device, + "m.direct_to_device", + hs.config.worker.writers.to_device, ) # The handler to call when we think a user's device list might be out of @@ -73,8 +74,8 @@ class DeviceMessageHandler: hs.get_device_handler().device_list_updater.user_device_resync ) else: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8f3a6b35a4..9a946a3cfe 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -61,8 +61,8 @@ class E2eKeysHandler: self._is_master = hs.config.worker_app is None if not self._is_master: - self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) else: # Only register this edu handler on master as it requires writing @@ -85,7 +85,7 @@ class E2eKeysHandler: async def query_devices( self, query_body: JsonDict, timeout: int, from_user_id: str ) -> JsonDict: - """ Handle a device key query from a client + """Handle a device key query from a client { "device_keys": { @@ -391,8 +391,7 @@ class E2eKeysHandler: async def on_federation_query_client_keys( self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: - """ Handle a device key query from a federated server - """ + """Handle a device key query from a federated server""" device_keys_query = query_body.get( "device_keys", {} ) # type: Dict[str, Optional[List[str]]] @@ -1065,7 +1064,9 @@ class E2eKeysHandler: return key, key_id, verify_key async def _retrieve_cross_signing_keys_for_remote_user( - self, user: UserID, desired_key_type: str, + self, + user: UserID, + desired_key_type: str, ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database @@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: @attr.s(slots=True) class SignatureListItem: - """An item in the signature list as used by upload_signatures_for_device_keys. - """ + """An item in the signature list as used by upload_signatures_for_device_keys.""" signing_key_id = attr.ib(type=str) target_user_id = attr.ib(type=str) @@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = await device_list_updater.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + new_device_ids = ( + await device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, + ) ) device_ids = device_ids + new_device_ids diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 539b4fc32e..3e23f82cf7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py
@@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler): room_id: Optional[str] = None, is_guest: bool = False, ) -> JsonDict: - """Fetches the events stream for a given user. - """ + """Fetches the events stream for a given user.""" if room_id: blocked = await self.store.is_room_blocked(room_id) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 61bc0c8bc6..51bdf97920 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -112,13 +112,13 @@ class _NewEventInfo: class FederationHandler(BaseHandler): """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + Responsible for: + a) handling received Pdus before handing them on as Events to the rest + of the homeserver (including auth and state conflict resolutions) + b) converting events that were produced by local clients that may need + to be sent to remote homeservers. + c) doing the necessary dances to invite remote users and join remote + rooms. """ def __init__(self, hs: "HomeServer"): @@ -151,11 +151,11 @@ class FederationHandler(BaseHandler): ) if hs.config.worker_app: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) - self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( - hs + self._maybe_store_room_on_outlier_membership = ( + ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: self._device_list_updater = hs.get_device_handler().device_list_updater @@ -173,7 +173,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: - """ Process a PDU received via a federation /send/ transaction, or + """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: @@ -371,10 +371,8 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "[%s %s] Requesting state at missing prev_event %s", - room_id, + "Requesting state at missing prev_event %s", event_id, - p, ) with nested_logging_context(p): @@ -394,12 +392,14 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), + state_map = ( + await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) ) # We need to give _process_received_pdu the actual state events @@ -691,9 +691,12 @@ class FederationHandler(BaseHandler): return fetched_events async def _process_received_pdu( - self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], ): - """ Called when we have a new pdu. We need to do auth checks and put it + """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. Args: @@ -805,7 +808,7 @@ class FederationHandler(BaseHandler): @log_function async def backfill(self, dest, room_id, limit, extremities): - """ Trigger a backfill request to `dest` for the given `room_id` + """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side has no new events to offer, this will return an empty list. @@ -1208,11 +1211,16 @@ class FederationHandler(BaseHandler): with nested_logging_context(event_id): try: event = await self.federation_client.get_pdu( - [destination], event_id, room_version, outlier=True, + [destination], + event_id, + room_version, + outlier=True, ) if event is None: logger.warning( - "Server %s didn't return event %s", destination, event_id, + "Server %s didn't return event %s", + destination, + event_id, ) return @@ -1239,7 +1247,8 @@ class FederationHandler(BaseHandler): if aid not in event_map ] persisted_events = await self.store.get_events( - auth_events, allow_rejected=True, + auth_events, + allow_rejected=True, ) event_infos = [] @@ -1255,7 +1264,9 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, room_id, event_infos, + destination, + room_id, + event_infos, ) def _sanity_check_event(self, ev): @@ -1291,7 +1302,7 @@ class FederationHandler(BaseHandler): raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): - """ Sends the invite to the remote server for signing. + """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. """ @@ -1314,7 +1325,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict ) -> Tuple[str, int]: - """ Attempts to join the `joinee` to the room `room_id` via the + """Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial @@ -1358,8 +1369,6 @@ class FederationHandler(BaseHandler): await self._clean_room_for_join(room_id) - handled_events = set() - try: # Try the host we successfully got a response to /make_join/ # request first. @@ -1379,10 +1388,6 @@ class FederationHandler(BaseHandler): auth_chain = ret["auth_chain"] auth_chain.sort(key=lambda e: e.depth) - handled_events.update([s.event_id for s in state]) - handled_events.update([a.event_id for a in auth_chain]) - handled_events.add(event.event_id) - logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join state: %s", state) @@ -1398,7 +1403,8 @@ class FederationHandler(BaseHandler): # so we can rely on it now. # await self.store.upsert_room_on_join( - room_id=room_id, room_version=room_version_obj, + room_id=room_id, + room_version=room_version_obj, ) max_stream_id = await self._persist_auth_tree( @@ -1535,7 +1541,7 @@ class FederationHandler(BaseHandler): async def on_make_join_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_join/ request, so we create a partial + """We've received a /make_join/ request, so we create a partial join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1560,7 +1566,8 @@ class FederationHandler(BaseHandler): is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( - "Got /make_join request for room %s we are no longer in", room_id, + "Got /make_join request for room %s we are no longer in", + room_id, ) raise NotFoundError("Not an active room on this server") @@ -1594,7 +1601,7 @@ class FederationHandler(BaseHandler): return event async def on_send_join_request(self, origin, pdu): - """ We have received a join event for a room. Fully process it and + """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ event = pdu @@ -1650,7 +1657,7 @@ class FederationHandler(BaseHandler): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion ): - """ We've got an invite event. Process and persist it. Sign it. + """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. """ @@ -1784,7 +1791,7 @@ class FederationHandler(BaseHandler): async def on_make_leave_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_leave/ request, so we create a partial + """We've received a /make_leave/ request, so we create a partial leave event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1974,8 +1981,7 @@ class FederationHandler(BaseHandler): return context async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) @@ -2001,8 +2007,7 @@ class FederationHandler(BaseHandler): return [] async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2208,7 +2213,11 @@ class FederationHandler(BaseHandler): for e_id in missing_auth_events: m_ev = await self.federation_client.get_pdu( - [origin], e_id, room_version=room_version, outlier=True, timeout=10000, + [origin], + e_id, + room_version=room_version, + outlier=True, + timeout=10000, ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -2358,7 +2367,9 @@ class FederationHandler(BaseHandler): ) logger.debug( - "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + "Doing soft-fail check for %s: state %s", + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state @@ -2711,7 +2722,7 @@ class FederationHandler(BaseHandler): async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: - """ Given a local and remote auth chain, find the differences. This + """Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth Params: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 71f11ef94a..bfb95e3eee 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler: async def get_users_in_group( self, group_id: str, requester_user_id: str ) -> JsonDict: - """Get users in a group - """ + """Get users in a group""" if self.is_mine_id(group_id): return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id @@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def create_group( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Create a group - """ + """Create a group""" logger.info("Asking to create group with ID: %r", group_id) @@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def join_group( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Request to join a group - """ + """Request to join a group""" if self.is_mine_id(group_id): await self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None @@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def accept_invite( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Accept an invite to a group - """ + """Accept an invite to a group""" if self.is_mine_id(group_id): await self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None @@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def invite( self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict ) -> JsonDict: - """Invite a user to a group - """ + """Invite a user to a group""" content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = await self.groups_server_handler.invite_to_group( @@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def on_invite( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """One of our users were invited to a group - """ + """One of our users were invited to a group""" # TODO: Support auto join and rejection if not self.is_mine_id(user_id): @@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def remove_user_from_group( self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict ) -> JsonDict: - """Remove a user from a group - """ + """Remove a user from a group""" if user_id == requester_user_id: token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" @@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def user_removed_from_group( self, group_id: str, user_id: str, content: JsonDict ) -> None: - """One of our users was removed/kicked from a group - """ + """One of our users was removed/kicked from a group""" # TODO: Check if user in group token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 4eb0036edd..ac81fa3678 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -75,7 +75,10 @@ class IdentityHandler(BaseHandler): ) def ratelimit_request_token_requests( - self, request: SynapseRequest, medium: str, address: str, + self, + request: SynapseRequest, + medium: str, + address: str, ): """Used to ratelimit requests to `/requestToken` by IP and address. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index fbd8df9dcc..78c3e5a10b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler): joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] receipt = await self.store.get_linearized_receipts_for_rooms( - joined_rooms, to_key=int(now_token.receipt_key), + joined_rooms, + to_key=int(now_token.receipt_key), ) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler): self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: - room_end_token = RoomStreamToken(None, event.stream_ordering,) + room_end_token = RoomStreamToken( + None, + event.stream_ordering, + ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] ) @@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler): membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True, + room_id, + user_id, + allow_departed_users=True, ) is_peeking = member_event_id is None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 3f9f594be6..1aded280c7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -67,8 +67,7 @@ logger = logging.getLogger(__name__) class MessageHandler: - """Contains some read only APIs to get state about a room - """ + """Contains some read only APIs to get state about a room""" def __init__(self, hs): self.auth = hs.get_auth() @@ -90,9 +89,13 @@ class MessageHandler: ) async def get_room_data( - self, user_id: str, room_id: str, event_type: str, state_key: str, + self, + user_id: str, + room_id: str, + event_type: str, + state_key: str, ) -> dict: - """ Get data from a room. + """Get data from a room. Args: user_id @@ -176,7 +179,10 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False, + self.storage, + user_id, + last_events, + filter_send_to_client=False, ) event = last_events[0] @@ -573,7 +579,7 @@ class EventCreationHandler: async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester ) -> bool: - """"Determine if an event to be sent is exempt from having to consent + """ "Determine if an event to be sent is exempt from having to consent to the privacy policy Args: @@ -795,9 +801,10 @@ class EventCreationHandler: """ if prev_event_ids is not None: - assert len(prev_event_ids) <= 10, ( - "Attempting to create an event with %i prev_events" - % (len(prev_event_ids),) + assert ( + len(prev_event_ids) <= 10 + ), "Attempting to create an event with %i prev_events" % ( + len(prev_event_ids), ) else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) @@ -823,7 +830,8 @@ class EventCreationHandler: ) if not third_party_result: logger.info( - "Event %s forbidden by third-party rules", event, + "Event %s forbidden by third-party rules", + event, ) raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN @@ -1325,7 +1333,11 @@ class EventCreationHandler: # Since this is a dummy-event it is OK if it is sent by a # shadow-banned user. await self.handle_new_client_event( - requester, event, context, ratelimit=False, ignore_shadow_ban=True, + requester, + event, + context, + ratelimit=False, + ignore_shadow_ban=True, ) return True except AuthError: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 3adc75fa4a..07db1e31e4 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py
@@ -41,13 +41,33 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -SESSION_COOKIE_NAME = b"oidc_session" +# we want the cookie to be returned to us even when the request is the POSTed +# result of a form on another domain, as is used with `response_mode=form_post`. +# +# Modern browsers will not do so unless we set SameSite=None; however *older* +# browsers (including all versions of Safari on iOS 12?) don't support +# SameSite=None, and interpret it as SameSite=Strict: +# https://bugs.webkit.org/show_bug.cgi?id=198181 +# +# As a rather painful workaround, we set *two* cookies, one with SameSite=None +# and one with no SameSite, in the hope that at least one of them will get +# back to us. +# +# Secure is necessary for SameSite=None (and, empirically, also breaks things +# on iOS 12.) +# +# Here we have the names of the cookies, and the options we use to set them. +_SESSION_COOKIES = [ + (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"), + (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"), +] #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. @@ -72,8 +92,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]}) class OidcHandler: - """Handles requests related to the OpenID Connect login flow. - """ + """Handles requests related to the OpenID Connect login flow.""" def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() @@ -149,26 +168,33 @@ class OidcHandler: # otherwise, it is presumably a successful response. see: # https://tools.ietf.org/html/rfc6749#section-4.1.2 - # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] - if session is None: + # Fetch the session cookie. See the comments on SESSION_COOKIES for why there + # are two. + + for cookie_name, _ in _SESSION_COOKIES: + session = request.getCookie(cookie_name) # type: Optional[bytes] + if session is not None: + break + else: logger.info("Received OIDC callback, with no session cookie") self._sso_handler.render_error( request, "missing_session", "No session cookie found" ) return - # Remove the cookie. There is a good chance that if the callback failed + # Remove the cookies. There is a good chance that if the callback failed # once, it will fail next time and the code will already be exchanged. - # Removing it early avoids spamming the provider with token requests. - request.addCookie( - SESSION_COOKIE_NAME, - b"", - path="/_synapse/oidc", - expires="Thu, Jan 01 1970 00:00:00 UTC", - httpOnly=True, - sameSite="lax", - ) + # Removing the cookies early avoids spamming the provider with token requests. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s" + % (cookie_name, options) + ) # Check for the state query parameter if b"state" not in request.args: @@ -215,8 +241,7 @@ class OidcHandler: class OidcError(Exception): - """Used to catch errors when calling the token_endpoint - """ + """Used to catch errors when calling the token_endpoint""" def __init__(self, error, error_description=None): self.error = error @@ -245,22 +270,27 @@ class OidcProvider: self._token_generator = token_generator + self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - provider.client_id, provider.client_secret, provider.client_auth_method, + provider.client_id, + provider.client_secret, + provider.client_auth_method, ) # type: ClientAuth self._client_auth_method = provider.client_auth_method - self._provider_metadata = OpenIDProviderMetadata( - issuer=provider.issuer, - authorization_endpoint=provider.authorization_endpoint, - token_endpoint=provider.token_endpoint, - userinfo_endpoint=provider.userinfo_endpoint, - jwks_uri=provider.jwks_uri, - ) # type: OpenIDProviderMetadata - self._provider_needs_discovery = provider.discover + + # cache of metadata for the identity provider (endpoint uris, mostly). This is + # loaded on-demand from the discovery endpoint (if discovery is enabled), with + # possible overrides from the config. Access via `load_metadata`. + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + # cache of JWKs used by the identity provider to sign tokens. Loaded on demand + # from the IdP's jwks_uri, if required. + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + self._user_mapping_provider = provider.user_mapping_provider_class( provider.user_mapping_provider_config ) @@ -286,7 +316,7 @@ class OidcProvider: self._sso_handler.register_identity_provider(self) - def _validate_metadata(self): + def _validate_metadata(self, m: OpenIDProviderMetadata) -> None: """Verifies the provider metadata. This checks the validity of the currently loaded provider. Not @@ -305,7 +335,6 @@ class OidcProvider: if self._skip_verification is True: return - m = self._provider_metadata m.validate_issuer() m.validate_authorization_endpoint() m.validate_token_endpoint() @@ -340,11 +369,7 @@ class OidcProvider: ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token - if m.get("jwks") is None: - if m.get("jwks_uri") is not None: - m.validate_jwks_uri() - else: - raise ValueError('"jwks_uri" must be set') + m.validate_jwks_uri() @property def _uses_userinfo(self) -> bool: @@ -361,11 +386,15 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) - async def load_metadata(self) -> OpenIDProviderMetadata: - """Load and validate the provider metadata. + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: + """Return the provider metadata. + + If this is the first call, the metadata is built from the config and from the + metadata discovery endpoint (if enabled), and then validated. If the metadata + is successfully validated, it is then cached for future use. - The values metadatas are discovered if ``oidc_config.discovery`` is - ``True`` and then cached. + Args: + force: If true, any cached metadata is discarded to force a reload. Raises: ValueError: if something in the provider is not valid @@ -373,18 +402,41 @@ class OidcProvider: Returns: The provider's metadata. """ - # If we are using the OpenID Discovery documents, it needs to be loaded once - # FIXME: should there be a lock here? - if self._provider_needs_discovery: - url = get_well_known_url(self._provider_metadata["issuer"], external=True) + if force: + # reset the cached call to ensure we get a new result + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + return await self._provider_metadata.get() + + async def _load_metadata(self) -> OpenIDProviderMetadata: + # start out with just the issuer (unlike the other settings, discovered issuer + # takes precedence over configured issuer, because configured issuer is + # required for discovery to take place.) + # + metadata = OpenIDProviderMetadata(issuer=self._config.issuer) + + # load any data from the discovery endpoint, if enabled + if self._config.discover: + url = get_well_known_url(self._config.issuer, external=True) metadata_response = await self._http_client.get_json(url) - # TODO: maybe update the other way around to let user override some values? - self._provider_metadata.update(metadata_response) - self._provider_needs_discovery = False + metadata.update(metadata_response) + + # override any discovered data with any settings in our config + if self._config.authorization_endpoint: + metadata["authorization_endpoint"] = self._config.authorization_endpoint + + if self._config.token_endpoint: + metadata["token_endpoint"] = self._config.token_endpoint + + if self._config.userinfo_endpoint: + metadata["userinfo_endpoint"] = self._config.userinfo_endpoint - self._validate_metadata() + if self._config.jwks_uri: + metadata["jwks_uri"] = self._config.jwks_uri - return self._provider_metadata + self._validate_metadata(metadata) + + return metadata async def load_jwks(self, force: bool = False) -> JWKS: """Load the JSON Web Key Set used to sign ID tokens. @@ -414,27 +466,27 @@ class OidcProvider: ] } """ + if force: + # reset the cached call to ensure we get a new result + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + return await self._jwks.get() + + async def _load_jwks(self) -> JWKS: if self._uses_userinfo: # We're not using jwt signing, return an empty jwk set return {"keys": []} - # First check if the JWKS are loaded in the provider metadata. - # It can happen either if the provider gives its JWKS in the discovery - # document directly or if it was already loaded once. metadata = await self.load_metadata() - jwk_set = metadata.get("jwks") - if jwk_set is not None and not force: - return jwk_set - # Loading the JWKS using the `jwks_uri` metadata + # Load the JWKS using the `jwks_uri` metadata. uri = metadata.get("jwks_uri") if not uri: + # this should be unreachable: load_metadata validates that + # there is a jwks_uri in the metadata if _uses_userinfo is unset raise RuntimeError('Missing "jwks_uri" in metadata') jwk_set = await self._http_client.get_json(uri) - # Caching the JWKS in the provider's metadata - self._provider_metadata["jwks"] = jwk_set return jwk_set async def _exchange_code(self, code: str) -> Token: @@ -492,7 +544,10 @@ class OidcProvider: # We're not using the SimpleHttpClient util methods as we don't want to # check the HTTP status code and we do the body encoding ourself. response = await self._http_client.request( - method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, ) # This is used in multiple error messages below @@ -693,14 +748,18 @@ class OidcProvider: ui_auth_session_id=ui_auth_session_id, ), ) - request.addCookie( - SESSION_COOKIE_NAME, - cookie, - path="/_synapse/client/oidc", - max_age="3600", - httpOnly=True, - sameSite="lax", - ) + + # Set the cookies. See the comments on _SESSION_COOKIES for why there are two. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=%s; Max-Age=3600; %s" + % (cookie_name, cookie.encode("utf-8"), options) + ) metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") @@ -949,7 +1008,9 @@ class OidcSessionTokenGenerator: A signed macaroon token with the session information. """ macaroon = pymacaroons.Macaroon( - location=self._server_name, identifier="key", key=self._macaroon_secret_key, + location=self._server_name, + identifier="key", + key=self._macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = session") diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5372753707..059064a4eb 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -197,7 +197,8 @@ class PaginationHandler: stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) r = await self.store.get_room_event_before_stream_ordering( - room_id, stream_ordering, + room_id, + stream_ordering, ) if not r: logger.warning( @@ -223,7 +224,12 @@ class PaginationHandler: # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", self._purge_history, purge_id, room_id, token, True, + "_purge_history", + self._purge_history, + purge_id, + room_id, + token, + True, ) def start_purge_history( @@ -389,7 +395,9 @@ class PaginationHandler: ) await self.hs.get_federation_handler().maybe_backfill( - room_id, curr_topo, limit=pagin_config.limit, + room_id, + curr_topo, + limit=pagin_config.limit, ) to_room_key = None diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 22d1e9d35c..7ba22d511f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -635,8 +635,7 @@ class PresenceHandler(BasePresenceHandler): self.external_process_last_updated_ms.pop(process_id, None) async def current_state_for_user(self, user_id): - """Get the current presence state for a user. - """ + """Get the current presence state for a user.""" res = await self.current_state_for_users([user_id]) return res[user_id] @@ -678,8 +677,7 @@ class PresenceHandler(BasePresenceHandler): self.federation.send_presence(states) async def incoming_presence(self, origin, content): - """Called when we receive a `m.presence` EDU from a remote server. - """ + """Called when we receive a `m.presence` EDU from a remote server.""" if not self._presence_enabled: return @@ -729,8 +727,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states(updates) async def set_state(self, target_user, state, ignore_status_msg=False): - """Set the presence state of the user. - """ + """Set the presence state of the user.""" status_msg = state.get("status_msg", None) presence = state["presence"] @@ -758,8 +755,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states([prev_state.copy_and_replace(**new_fields)]) async def is_visible(self, observed_user, observer_user): - """Returns whether a user can see another user's presence. - """ + """Returns whether a user can see another user's presence.""" observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) @@ -953,8 +949,7 @@ class PresenceHandler(BasePresenceHandler): def should_notify(old_state, new_state): - """Decides if a presence state change should be sent to interested parties. - """ + """Decides if a presence state change should be sent to interested parties.""" if old_state == new_state: return False diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4b102ff9a9..b04ee5f430 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -305,7 +305,8 @@ class ProfileHandler(BaseHandler): # This must be done by the target user himself. if by_admin: requester = create_requester( - target_user, authenticated_entity=requester.authenticated_entity, + target_user, + authenticated_entity=requester.authenticated_entity, ) if len(self.hs.config.replicate_user_profiles_to) > 0: diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index cc21fc2284..6a6c528849 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
@@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler): ) else: hs.get_federation_registry().register_instances_for_edu( - "m.receipt", hs.config.worker.writers.receipts, + "m.receipt", + hs.config.worker.writers.receipts, ) self.clock = self.hs.get_clock() self.state = hs.get_state_handler() async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: - """Called when we receive an EDU of type m.receipt from a remote HS. - """ + """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] for room_id, room_values in content.items(): for receipt_type, users in room_values.items(): @@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler): await self._handle_new_receipts(receipts) async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: - """Takes a list of receipts, stores them and informs the notifier. - """ + """Takes a list of receipts, stores them and informs the notifier.""" min_batch_id = None # type: Optional[int] max_batch_id = None # type: Optional[int] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ab4d5ccc1c..553fcb5b66 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -65,8 +65,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = RegisterDeviceReplicationServlet.make_client( hs ) - self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( - hs + self._post_registration_client = ( + ReplicationPostRegisterActionsServlet.make_client(hs) ) else: self.device_handler = hs.get_device_handler() @@ -204,12 +204,15 @@ class RegistrationHandler(BaseHandler): self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( - threepid, localpart, user_agent_ips or [], + threepid, + localpart, + user_agent_ips or [], ) if result == RegistrationBehaviour.DENY: logger.info( - "Blocked registration of %r", localpart, + "Blocked registration of %r", + localpart, ) # We return a 429 to make it not obvious that they've been # denied. @@ -218,7 +221,8 @@ class RegistrationHandler(BaseHandler): shadow_banned = result == RegistrationBehaviour.SHADOW_BAN if shadow_banned: logger.info( - "Shadow banning registration of %r", localpart, + "Shadow banning registration of %r", + localpart, ) # do not check_auth_blocking if the call is coming through the Admin API @@ -401,7 +405,9 @@ class RegistrationHandler(BaseHandler): config["room_alias_name"] = room_alias.localpart info, _ = await room_creation_handler.create_room( - fake_requester, config=config, ratelimit=False, + fake_requester, + config=config, + ratelimit=False, ) # If the room does not require an invite, but another user @@ -859,7 +865,10 @@ class RegistrationHandler(BaseHandler): return await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) # And we add an email pusher for them by default, but only @@ -911,5 +920,8 @@ class RegistrationHandler(BaseHandler): raise await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 736070d574..2271c60afc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -198,7 +198,9 @@ class RoomCreationHandler(BaseHandler): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = await self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], room_version=new_version, + creator_id=user_id, + is_public=r["is_public"], + room_version=new_version, ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -236,7 +238,9 @@ class RoomCreationHandler(BaseHandler): # now send the tombstone await self.event_creation_handler.handle_new_client_event( - requester=requester, event=tombstone_event, context=tombstone_context, + requester=requester, + event=tombstone_event, + context=tombstone_context, ) old_room_state = await tombstone_context.get_current_state_ids() @@ -257,7 +261,10 @@ class RoomCreationHandler(BaseHandler): # finally, shut down the PLs in the old room, and update them in the new # room. await self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, + old_room_id, + new_room_id, + old_room_state, ) return new_room_id @@ -437,17 +444,20 @@ class RoomCreationHandler(BaseHandler): # Copy over user power levels now as this will not be possible with >100PL users once # the room has been created - # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) - state_default = power_levels.get("state_default", 0) - ban = power_levels.get("ban") + state_default = power_levels.get("state_default", 50) + ban = power_levels.get("ban", 50) needed_power_level = max(state_default, ban, max(event_power_levels.values())) + # Get the user's current power level, this matches the logic in get_user_power_level, + # but without the entire state map. + user_power_levels = power_levels.setdefault("users", {}) + users_default = power_levels.get("users_default", 0) + current_power_level = user_power_levels.get(user_id, users_default) # Raise the requester's power level in the new room if necessary - current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - power_levels["users"][user_id] = needed_power_level + user_power_levels[user_id] = needed_power_level await self._send_events_for_new_room( requester, @@ -579,7 +589,7 @@ class RoomCreationHandler(BaseHandler): ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, ) -> Tuple[dict, int]: - """ Creates a new room. + """Creates a new room. Args: requester: @@ -706,7 +716,9 @@ class RoomCreationHandler(BaseHandler): is_public = visibility == "public" room_id = await self._generate_room_id( - creator_id=user_id, is_public=is_public, room_version=room_version, + creator_id=user_id, + is_public=is_public, + room_version=room_version, ) # Check whether this visibility value is blocked by a third party module @@ -849,7 +861,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - # Always wait for room creation to progate before returning + # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), "events", @@ -901,7 +913,10 @@ class RoomCreationHandler(BaseHandler): _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False, ignore_shadow_ban=True, + creator, + event, + ratelimit=False, + ignore_shadow_ban=True, ) return last_stream_id @@ -1002,7 +1017,10 @@ class RoomCreationHandler(BaseHandler): return last_sent_stream_id async def _generate_room_id( - self, creator_id: str, is_public: bool, room_version: RoomVersion, + self, + creator_id: str, + is_public: bool, + room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index eb3193e554..312ebc139c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -234,7 +234,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # do it up front for efficiency.) if txn_id and requester.access_token_id: existing_event_id = await self.store.get_event_id_from_transaction_id( - room_id, requester.user.to_string(), requester.access_token_id, txn_id, + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, ) if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) @@ -281,7 +284,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit, + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, ) if event.membership == Membership.LEAVE: @@ -657,7 +664,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # send the rejection to the inviter's HS (with fallback to # local event) return await self.remote_reject_invite( - invite.event_id, txn_id, requester, content, + invite.event_id, + txn_id, + requester, + content, ) # the inviter was on our server, but has now left. Carry on @@ -1178,8 +1188,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" # filter ourselves out of remote_room_hosts: do_invite_join ignores it # and if it is the only entry we'd like to return a 404 rather than a # 500. @@ -1362,7 +1371,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): event.internal_metadata.out_of_band_membership = True result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[UserID.from_string(target_user)], + requester, + event, + context, + extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering assert result_event.internal_metadata.stream_ordering @@ -1396,8 +1408,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room - """ + """Implements RoomMemberHandler._user_left_room""" user_left_room(self.distributor, target, room_id) async def forget(self, user: UserID, room_id: str) -> None: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 3de63e885e..926d09f40c 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py
@@ -49,8 +49,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") @@ -128,8 +127,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret["event_id"], ret["stream_id"] async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room - """ + """Implements RoomMemberHandler._user_left_room""" await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..a9645b77d8 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.config.saml2_config import SamlAttributeRequirement from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -122,7 +121,8 @@ class SamlHandler(BaseHandler): now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( - creation_time=now, ui_auth_session_id=ui_auth_session_id, + creation_time=now, + ui_auth_session_id=ui_auth_session_id, ) for key, value in info["headers"]: @@ -239,12 +239,10 @@ class SamlHandler(BaseHandler): # Ensure that the attributes of the logged in user meet the required # attributes. - for requirement in self._saml2_attribute_requirements: - if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._sso_handler.render_error( - request, "unauthorised", "You are not authorised to log in here." - ) - return + if not self._sso_handler.check_required_attributes( + request, saml2_auth.ava, self._saml2_attribute_requirements + ): + return # Call the mapper to register/login the user try: @@ -373,21 +371,6 @@ class SamlHandler(BaseHandler): del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: - values = ava.get(req.attribute, []) - for v in values: - if v == req.value: - return True - - logger.info( - "SAML2 attribute %s did not match required value '%s' (was '%s')", - req.attribute, - req.value, - values, - ) - return False - - DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) @@ -468,7 +451,8 @@ class DefaultSamlMappingProvider: mxid_source = saml_response.ava[self._mxid_source_attribute][0] except KeyError: logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + "SAML2 response lacks a '%s' attestation", + self._mxid_source_attribute, ) raise SynapseError( 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 96ccd991ed..514b1f69d8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -16,10 +16,12 @@ import abc import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, Iterable, + List, Mapping, Optional, Set, @@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -324,7 +327,8 @@ class SsoHandler: # Check if we already have a mapping for this user. previously_registered_user_id = await self._store.get_user_by_external_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # A match was found, return the user ID. @@ -413,7 +417,8 @@ class SsoHandler: with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # Check for grandfathering of users. @@ -458,7 +463,8 @@ class SsoHandler: ) async def _call_attribute_mapper( - self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + self, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], ) -> UserAttributes: """Call the attribute mapper function in a loop, until we get a unique userid""" for i in range(self._MAP_USERNAME_RETRIES): @@ -629,7 +635,8 @@ class SsoHandler: """ user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) user_id_to_verify = await self._auth_handler.get_session_data( @@ -668,7 +675,8 @@ class SsoHandler: # render an error page. html = self._bad_user_template.render( - server_name=self._server_name, user_id_to_verify=user_id_to_verify, + server_name=self._server_name, + user_id_to_verify=user_id_to_verify, ) respond_with_html(request, 200, html) @@ -692,7 +700,9 @@ class SsoHandler: raise SynapseError(400, "unknown session") async def check_username_availability( - self, localpart: str, session_id: str, + self, + localpart: str, + session_id: str, ) -> bool: """Handle an "is username available" callback check @@ -830,7 +840,8 @@ class SsoHandler: ) attributes = UserAttributes( - localpart=session.chosen_localpart, emails=session.emails_to_use, + localpart=session.chosen_localpart, + emails=session.emails_to_use, ) if session.use_display_name: @@ -893,6 +904,41 @@ class SsoHandler: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return False + + return True + def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: """Extract the session ID from the cookie @@ -903,3 +949,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: if not session_id: raise SynapseError(code=400, msg="missing session_id") return session_id.decode("ascii", errors="replace") + + +def _check_attribute_requirement( + attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement +) -> bool: + """Check if SSO attributes meet the proper requirements. + + Args: + attributes: A mapping of attributes to an iterable of one or more values. + requirement: The configured requirement to check. + + Returns: + True if the required attribute was found and had a proper value. + """ + if req.attribute not in attributes: + logger.info("SSO attribute missing: %s", req.attribute) + return False + + # If the requirement is None, the attribute existing is enough. + if req.value is None: + return True + + values = attributes[req.attribute] + if req.value in values: + return True + + logger.info( + "SSO attribute %s did not match required value '%s' (was '%s')", + req.attribute, + req.value, + values, + ) + return False diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 0b5e62da1b..388dec5831 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py
@@ -65,8 +65,7 @@ class StatsHandler: self.clock.call_later(0, self.notify_new_event) def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.stats_enabled or self._is_processing: return diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e8947e0f9b..fa6794734b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -353,8 +353,7 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now. - """ + """Get the sync for client needed to match what the server has now.""" return await self.generate_sync_result(sync_config, since_token, full_state) async def push_rules_for_user(self, user: UserID) -> JsonDict: @@ -578,7 +577,7 @@ class SyncHandler: stream_position: StreamToken, state_filter: StateFilter = StateFilter.all(), ) -> StateMap[str]: - """ Get the room state at a particular stream position + """Get the room state at a particular stream position Args: room_id: room for which to get state @@ -612,7 +611,7 @@ class SyncHandler: state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: - """ Works out a room summary block for this room, summarising the number + """Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds state events to 'state' if needed to describe the heroes. @@ -757,7 +756,7 @@ class SyncHandler: now_token: StreamToken, full_state: bool, ) -> MutableStateMap[EventBase]: - """ Works out the difference in state between the start of the timeline + """Works out the difference in state between the start of the timeline and the previous sync. Args: @@ -834,8 +833,10 @@ class SyncHandler: ) elif batch.limited: if batch: - state_at_timeline_start = await self.state_store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_at_timeline_start = ( + await self.state_store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: # We can get here if the user has ignored the senders of all @@ -969,8 +970,7 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result. - """ + """Generates a sync result.""" # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -1046,8 +1046,8 @@ class SyncHandler: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( - user_id, device_id + unused_fallback_key_types = ( + await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) logger.debug("Fetching group data") @@ -1196,8 +1196,10 @@ class SyncHandler: # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) - user_signatures_changed = await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key + user_signatures_changed = ( + await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) ) users_that_have_changed.update(user_signatures_changed) @@ -1413,8 +1415,10 @@ class SyncHandler: logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) ) # If there is ignored users account data and it matches the proper type, @@ -1524,8 +1528,7 @@ class SyncHandler: async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> _RoomChanges: - """Gets the the changes that have happened since the last sync. - """ + """Gets the the changes that have happened since the last sync.""" user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3f0dfc7a74..096d199f4c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -61,7 +61,8 @@ class FollowerTypingHandler: if hs.config.worker.writers.typing != hs.get_instance_name(): hs.get_federation_registry().register_instance_for_edu( - "m.typing", hs.config.worker.writers.typing, + "m.typing", + hs.config.worker.writers.typing, ) # map room IDs to serial numbers @@ -76,8 +77,7 @@ class FollowerTypingHandler: self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self) -> None: - """Reset the typing handler's data caches. - """ + """Reset the typing handler's data caches.""" # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing @@ -149,8 +149,7 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] ) -> None: - """Should be called whenever we receive updates for typing stream. - """ + """Should be called whenever we receive updates for typing stream.""" if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8aedf5072e..3dfb0a26c2 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler): return results def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.update_user_directory: return @@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler): ) async def handle_user_deactivated(self, user_id: str) -> None: - """Called when a user ID is deactivated - """ + """Called when a user ID is deactivated""" # FIXME(#3714): We should probably do this in the same worker as all # the other changes. await self.store.remove_from_user_dir(user_id) @@ -172,8 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler): await self.store.update_user_directory_stream_pos(max_pos) async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: - """Called with the state deltas to process - """ + """Called with the state deltas to process""" for delta in deltas: typ = delta["type"] state_key = delta["state_key"]