diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/auth.py | 54 | ||||
-rw-r--r-- | synapse/handlers/cas_handler.py | 3 | ||||
-rw-r--r-- | synapse/handlers/device.py | 12 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 223 | ||||
-rw-r--r-- | synapse/handlers/e2e_room_keys.py | 91 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 6 | ||||
-rw-r--r-- | synapse/handlers/identity.py | 28 | ||||
-rw-r--r-- | synapse/handlers/message.py | 2 | ||||
-rw-r--r-- | synapse/handlers/oidc_handler.py | 103 | ||||
-rw-r--r-- | synapse/handlers/register.py | 7 | ||||
-rw-r--r-- | synapse/handlers/room.py | 7 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 31 | ||||
-rw-r--r-- | synapse/handlers/saml_handler.py | 5 | ||||
-rw-r--r-- | synapse/handlers/sso.py | 205 |
14 files changed, 555 insertions, 222 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0e98db22b3..648fe91f53 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi +from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable @@ -567,16 +568,6 @@ class AuthHandler(BaseHandler): session.session_id, login_type, result ) except LoginError as e: - if login_type == LoginType.EMAIL_IDENTITY: - # riot used to have a bug where it would request a new - # validation token (thus sending a new email) each time it - # got a 401 with a 'flows' field. - # (https://github.com/vector-im/vector-web/issues/2447). - # - # Grandfather in the old behaviour for now to avoid - # breaking old riot deployments. - raise - # this step failed. Merge the error dict into the response # so that the client can have another go. errordict = e.error_dict() @@ -1387,7 +1378,9 @@ class AuthHandler(BaseHandler): ) return self._sso_auth_confirm_template.render( - description=session.description, redirect_url=redirect_url, + description=session.description, + redirect_url=redirect_url, + idp=sso_auth_provider, ) async def complete_sso_login( @@ -1396,6 +1389,7 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1406,6 +1400,8 @@ class AuthHandler(BaseHandler): process. extra_attributes: Extra attributes which will be passed to the client during successful login. Must be JSON serializable. + new_user: True if we should use wording appropriate to a user who has just + registered. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1414,8 +1410,17 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return + profile = await self.store.get_profileinfo( + UserID.from_string(registered_user_id).localpart + ) + self._complete_sso_login( - registered_user_id, request, client_redirect_url, extra_attributes + registered_user_id, + request, + client_redirect_url, + extra_attributes, + new_user=new_user, + user_profile_data=profile, ) def _complete_sso_login( @@ -1424,12 +1429,18 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, + user_profile_data: Optional[ProfileInfo] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + + if user_profile_data is None: + user_profile_data = ProfileInfo(None, None) + # Store any extra attributes which will be passed in the login response. # Note that this is per-user so it may overwrite a previous value, this # is considered OK since the newest SSO attributes should be most valid. @@ -1461,12 +1472,27 @@ class AuthHandler(BaseHandler): # Remove the query parameters from the redirect URL to get a shorter version of # it. This is only to display a human-readable URL in the template, but not the # URL we redirect users to. - redirect_url_no_params = client_redirect_url.split("?")[0] + url_parts = urllib.parse.urlsplit(client_redirect_url) + + if url_parts.scheme == "https": + # for an https uri, just show the netloc (ie, the hostname. Specifically, + # the bit between "//" and "/"; this includes any potential + # "username:password@" prefix.) + display_url = url_parts.netloc + else: + # for other uris, strip the query-params (including the login token) and + # fragment. + display_url = urllib.parse.urlunsplit( + (url_parts.scheme, url_parts.netloc, url_parts.path, "", "") + ) html = self._sso_redirect_confirm_template.render( - display_url=redirect_url_no_params, + display_url=display_url, redirect_url=redirect_url, server_name=self._server_name, + new_user=new_user, + user_id=registered_user_id, + user_profile=user_profile_data, ) respond_with_html(request, 200, html) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 21b6bc4992..bd35d1fb87 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,9 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" - # we do not currently support icons for CAS auth, but this is required by + # we do not currently support brands/icons for CAS auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index debb1b4f29..0863154f7a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: + async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ Retrieve the given user's devices @@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device(self, user_id: str, device_id: str) -> JsonDict: """ Retrieve the given device Args: @@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] + device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -946,8 +946,8 @@ class DeviceListUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[Dict[str, Any]], - self_signing_key: Optional[Dict[str, Any]], + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], ) -> List[str]: """Process the given new master and self-signing key for the given remote user. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 929752150d..8f3a6b35a4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( + JsonDict, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class E2eKeysHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -78,7 +82,9 @@ class E2eKeysHandler: ) @trace - async def query_devices(self, query_body, timeout, from_user_id): + async def query_devices( + self, query_body: JsonDict, timeout: int, from_user_id: str + ) -> JsonDict: """ Handle a device key query from a client { @@ -98,12 +104,14 @@ class E2eKeysHandler: } Args: - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] # separate users by domain. # make a map from domain to user_id to device_ids @@ -121,7 +129,8 @@ class E2eKeysHandler: set_tag("remote_key_query", remote_queries) # First get local devices. - failures = {} + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -135,9 +144,10 @@ class E2eKeysHandler: ) # Now attempt to get any remote devices from our local cache. - remote_queries_not_in_cache = {} + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] if remote_queries: - query_list = [] + query_list = [] # type: List[Tuple[str, Optional[str]]] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) @@ -284,15 +294,15 @@ class E2eKeysHandler: return ret async def get_cross_signing_keys_from_cache( - self, query, from_user_id + self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: - query (Iterable[string]) an iterable of user IDs. A dict whose keys + query: an iterable of user IDs. A dict whose keys are user IDs satisfies this, so the query format used for query_devices can be used here. - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. @@ -315,14 +325,12 @@ class E2eKeysHandler: if "self_signing" in user_info: self_signing_keys[user_id] = user_info["self_signing"] - if ( - from_user_id in keys - and keys[from_user_id] is not None - and "user_signing" in keys[from_user_id] - ): - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id: + from_user_key = keys.get(from_user_id) + if from_user_key and "user_signing" in from_user_key: + user_signing_keys[from_user_id] = from_user_key["user_signing"] return { "master_keys": master_keys, @@ -344,9 +352,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] + local_query = [] # type: List[Tuple[str, Optional[str]]] - result_dict = {} + result_dict = {} # type: Dict[str, Dict[str, dict]] for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -380,10 +388,14 @@ class E2eKeysHandler: log_kv(results) return result_dict - async def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys( + self, query_body: Dict[str, Dict[str, Optional[List[str]]]] + ) -> JsonDict: """ Handle a device key query from a federated server """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Optional[List[str]]] res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -397,31 +409,34 @@ class E2eKeysHandler: return ret @trace - async def claim_one_time_keys(self, query, timeout): - local_query = [] - remote_queries = {} + async def claim_one_time_keys( + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + ) -> JsonDict: + local_query = [] # type: List[Tuple[str, str, str]] + remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] - for user_id, device_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in device_keys.items(): + for device_id, algorithm in one_time_keys.items(): local_query.append((user_id, device_id, algorithm)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_keys + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) results = await self.store.claim_e2e_one_time_keys(local_query) - json_result = {} - failures = {} + # A map of user ID -> device ID -> key ID -> key. + json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + failures = {} # type: Dict[str, JsonDict] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_bytes) + key_id: json_decoder.decode(json_str) } @trace @@ -468,7 +483,9 @@ class E2eKeysHandler: return {"one_time_keys": json_result, "failures": failures} @tag_args - async def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: time_now = self.clock.time_msec() @@ -543,8 +560,8 @@ class E2eKeysHandler: return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( - self, user_id, device_id, time_now, one_time_keys - ): + self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict + ) -> None: logger.info( "Adding one_time_keys %r for device %r for user %r at %d", one_time_keys.keys(), @@ -585,12 +602,14 @@ class E2eKeysHandler: log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - async def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user( + self, user_id: str, keys: JsonDict + ) -> JsonDict: """Upload signing keys for cross-signing Args: - user_id (string): the user uploading the keys - keys (dict[string, dict]): the signing keys + user_id: the user uploading the keys + keys: the signing keys """ # if a master key is uploaded, then check it. Otherwise, load the @@ -667,16 +686,17 @@ class E2eKeysHandler: return {} - async def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys( + self, user_id: str, signatures: JsonDict + ) -> JsonDict: """Upload device signatures for cross-signing Args: - user_id (string): the user uploading the signatures - signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys. This is the submission from the user; an - exception will be raised if it is malformed. + user_id: the user uploading the signatures + signatures: map of users to devices to signed keys. This is the submission + from the user; an exception will be raised if it is malformed. Returns: - dict: response to be sent back to the client. The response will have + The response to be sent back to the client. The response will have a "failures" key, which will be a dict mapping users to devices to errors for the signatures that failed. Raises: @@ -719,7 +739,9 @@ class E2eKeysHandler: return {"failures": failures} - async def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures( + self, user_id: str, signatures: JsonDict + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -731,15 +753,14 @@ class E2eKeysHandler: signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure - reasons + A tuple of a list of signatures to store, and a map of users to + devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -834,19 +855,24 @@ class E2eKeysHandler: return signature_list, failures def _check_master_key_signature( - self, user_id, master_key_id, signed_master_key, stored_master_key, devices - ): + self, + user_id: str, + master_key_id: str, + signed_master_key: JsonDict, + stored_master_key: JsonDict, + devices: Dict[str, Dict[str, JsonDict]], + ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user whose master key is being checked - master_key_id (string): the ID of the user's master key - signed_master_key (dict): the user's signed master key that was uploaded - stored_master_key (dict): our previously-stored copy of the user's master key - devices (iterable(dict)): the user's devices + user_id: the user whose master key is being checked + master_key_id: the ID of the user's master key + signed_master_key: the user's signed master key that was uploaded + stored_master_key: our previously-stored copy of the user's master key + devices: the user's devices Returns: - list[SignatureListItem]: a list of signatures to store + A list of signatures to store Raises: SynapseError: if a signature is invalid @@ -877,25 +903,26 @@ class E2eKeysHandler: return master_key_signature_list - async def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures( + self, user_id: str, signatures: Dict[str, dict] + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id: the user uploading the keys + signatures: map of users to devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure + A list of signatures to store, and a map of users to devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -983,7 +1010,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None - ): + ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -997,8 +1024,7 @@ class E2eKeysHandler: This affects what signatures are fetched. Returns: - dict, str, VerifyKey: the raw key data, the key ID, and the - signedjson verify key + The raw key data, the key ID, and the signedjson verify key Raises: NotFoundError: if the key is not found @@ -1135,16 +1161,18 @@ class E2eKeysHandler: return desired_key, desired_key_id, desired_verify_key -def _check_cross_signing_key(key, user_id, key_type, signing_key=None): +def _check_cross_signing_key( + key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None +) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. Args: - key (dict): the key data to verify - user_id (str): the user whose key is being checked - key_type (str): the type of key that the key should be - signing_key (VerifyKey): (optional) the signing key that the key should - be signed with. If omitted, signatures will not be checked. + key: the key data to verify + user_id: the user whose key is being checked + key_type: the type of key that the key should be + signing_key: the signing key that the key should be signed with. If + omitted, signatures will not be checked. """ if ( key.get("user_id") != user_id @@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) -def _check_device_signature(user_id, verify_key, signed_device, stored_device): +def _check_device_signature( + user_id: str, + verify_key: VerifyKey, + signed_device: JsonDict, + stored_device: JsonDict, +) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an exception if an error is detected. Args: - user_id (str): the user ID whose signature is being checked - verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the uploaded signed device data - stored_device (dict): our previously stored copy of the device + user_id: the user ID whose signature is being checked + verify_key: the key to verify the device with + signed_device: the uploaded signed device data + stored_device: our previously stored copy of the device Raises: SynapseError: if the signature was invalid or the sent device is not the @@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) -def _exception_to_failure(e): +def _exception_to_failure(e: Exception) -> JsonDict: if isinstance(e, SynapseError): return {"status": e.code, "errcode": e.errcode, "message": str(e)} @@ -1218,7 +1251,7 @@ def _exception_to_failure(e): return {"status": 503, "message": str(e)} -def _one_time_keys_match(old_key_json, new_key): +def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly @@ -1239,16 +1272,16 @@ class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ - signing_key_id = attr.ib() - target_user_id = attr.ib() - target_device_id = attr.ib() - signature = attr.ib() + signing_key_id = attr.ib(type=str) + target_user_id = attr.ib(type=str) + target_device_id = attr.ib(type=str) + signature = attr.ib(type=JsonDict) class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs, e2e_keys_handler): + def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater: iterable=True, ) - async def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. Args: - origin (string): the server that sent the EDU - edu_content (dict): the contents of the EDU + origin: the server that sent the EDU + edu_content: the contents of the EDU """ user_id = edu_content.pop("user_id") @@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater: await self._handle_signing_key_updates(user_id) - async def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id: str) -> None: """Actually handle pending updates. Args: - user_id (string): the user whose updates we are processing + user_id: the user whose updates we are processing """ device_handler = self.e2e_keys_handler.device_handler @@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] + device_ids = [] # type: List[str] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f01b090772..622cae23be 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import ( Codes, @@ -24,8 +25,12 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class E2eRoomKeysHandler: The actual payload of the encrypted keys is completely opaque to the handler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # Used to lock whenever a client is uploading key data. This prevents collisions @@ -48,21 +53,27 @@ class E2eRoomKeysHandler: self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - async def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> List[JsonDict]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. Args: - user_id(str): the user whose keys we're getting - version(str): the version ID of the backup we're getting keys from - room_id(string): room ID to get keys for, for None to get keys for all rooms - session_id(string): session ID to get keys for, for None to get keys for all + user_id: the user whose keys we're getting + version: the version ID of the backup we're getting keys from + room_id: room ID to get keys for, for None to get keys for all rooms + session_id: session ID to get keys for, for None to get keys for all sessions Raises: NotFoundError: if the backup version does not exist Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -86,17 +97,23 @@ class E2eRoomKeysHandler: return results @trace - async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. Args: - user_id(str): the user whose backup we're deleting - version(str): the version ID of the backup we're deleting - room_id(string): room ID to delete keys for, for None to delete keys for all + user_id: the user whose backup we're deleting + version: the version ID of the backup we're deleting + room_id: room ID to delete keys for, for None to delete keys for all rooms - session_id(string): session ID to delete keys for, for None to delete keys + session_id: session ID to delete keys for, for None to delete keys for all sessions Raises: NotFoundError: if the backup version does not exist @@ -128,15 +145,17 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @trace - async def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys( + self, user_id: str, version: str, room_keys: JsonDict + ) -> JsonDict: """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_keys(dict): a nested dict describing the room_keys we're setting: + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_keys: a nested dict describing the room_keys we're setting: { "rooms": { @@ -254,14 +273,16 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @staticmethod - def _should_replace_room_key(current_room_key, room_key): + def _should_replace_room_key( + current_room_key: Optional[JsonDict], room_key: JsonDict + ) -> bool: """ Determine whether to replace a given current_room_key (if any) with a newly uploaded room_key backup Args: - current_room_key (dict): Optional, the current room_key dict if any - room_key (dict): The new room_key dict which may or may not be fit to + current_room_key: Optional, the current room_key dict if any + room_key : The new room_key dict which may or may not be fit to replace the current_room_key Returns: @@ -286,14 +307,14 @@ class E2eRoomKeysHandler: return True @trace - async def create_version(self, user_id, version_info): + async def create_version(self, user_id: str, version_info: JsonDict) -> str: """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. Args: - user_id(str): the user whose backup version we're creating - version_info(dict): metadata about the new version being created + user_id: the user whose backup version we're creating + version_info: metadata about the new version being created { "algorithm": "m.megolm_backup.v1", @@ -301,7 +322,7 @@ class E2eRoomKeysHandler: } Returns: - A deferred of a string that gives the new version number. + The new version number. """ # TODO: Validate the JSON to make sure it has the right keys. @@ -313,17 +334,19 @@ class E2eRoomKeysHandler: ) return new_version - async def get_version_info(self, user_id, version=None): + async def get_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're querying - version(str): Optional; if None gives the most recent version + user_id: the user whose current backup version we're querying + version: Optional; if None gives the most recent version otherwise a historical one. Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of a info dict that gives the info about the new version. + A info dict that gives the info about the new version. { "version": "1234", @@ -346,7 +369,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id, version=None): + async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: @@ -366,17 +389,19 @@ class E2eRoomKeysHandler: raise @trace - async def update_version(self, user_id, version, version_info): + async def update_version( + self, user_id: str, version: str, version_info: JsonDict + ) -> JsonDict: """Update the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're updating - version(str): the backup version we're updating - version_info(dict): the new information about the backup + user_id: the user whose current backup version we're updating + version: the backup version we're updating + version_info: the new information about the backup Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of an empty dict. + An empty dict. """ if "version" not in version_info: version_info["version"] = version diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b6dc7f99b6..eddc7582d0 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1617,6 +1617,12 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + # We don't rate limit based on room ID, as that should be done by + # sending server. + member_handler.ratelimit_invite(None, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..4f7137539b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2a7d567fa..a15336bf00 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -174,7 +174,7 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False + self.storage, user_id, last_events, filter_send_to_client=False, ) event = last_events[0] diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 1607e12935..3adc75fa4a 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -102,7 +102,7 @@ class OidcHandler: ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. Instead, we call @@ -123,7 +123,6 @@ class OidcHandler: Args: request: the incoming request from the browser. """ - # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: @@ -137,8 +136,12 @@ class OidcHandler: # either the provider misbehaving or Synapse being misconfigured. # The only exception of that is "access_denied", where the user # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) + logger.log( + logging.INFO if error == "access_denied" else logging.ERROR, + "Received OIDC callback with error: %s %s", + error, + description, + ) self._sso_handler.render_error(request, error, description) return @@ -149,7 +152,7 @@ class OidcHandler: # Fetch the session cookie session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] if session is None: - logger.info("No session cookie found") + logger.info("Received OIDC callback, with no session cookie") self._sso_handler.render_error( request, "missing_session", "No session cookie found" ) @@ -169,7 +172,7 @@ class OidcHandler: # Check for the state query parameter if b"state" not in request.args: - logger.info("State parameter is missing") + logger.info("Received OIDC callback, with no state parameter") self._sso_handler.render_error( request, "invalid_request", "State parameter is missing" ) @@ -183,14 +186,16 @@ class OidcHandler: session, state ) except (MacaroonDeserializationException, ValueError) as e: - logger.exception("Invalid session") + logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") + logger.exception("Could not verify session for OIDC callback") self._sso_handler.render_error(request, "mismatching_session", str(e)) return + logger.info("Received OIDC callback for IdP %s", session_data.idp_id) + oidc_provider = self._providers.get(session_data.idp_id) if not oidc_provider: logger.error("OIDC session uses unknown IdP %r", oidc_provider) @@ -274,6 +279,9 @@ class OidcProvider: # MXC URI for icon for this auth provider self.idp_icon = provider.idp_icon + # optional brand identifier for this auth provider + self.idp_brand = provider.idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) @@ -562,6 +570,7 @@ class OidcProvider: Returns: UserInfo: an object representing the user. """ + logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() resp = await self._http_client.get_json( @@ -569,6 +578,8 @@ class OidcProvider: headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, ) + logger.debug("Retrieved user info from userinfo endpoint: %r", resp) + return UserInfo(resp) async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: @@ -597,17 +608,19 @@ class OidcProvider: claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) claim_options = {"iss": {"values": [metadata["issuer"]]}} + id_token = token["id_token"] + logger.debug("Attempting to decode JWT id_token %r", id_token) + # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() try: claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, @@ -617,13 +630,15 @@ class OidcProvider: logger.info("Reloading JWKS after decode error") jwk_set = await self.load_jwks(force=True) # try reloading the jwks claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, claims_params=claims_params, ) + logger.debug("Decoded id_token JWT %r; validating", claims) + claims.validate(leeway=120) # allows 2 min of clock skew return UserInfo(claims) @@ -640,7 +655,7 @@ class OidcProvider: - ``client_id``: the client ID set in ``oidc_config.client_id`` - ``response_type``: ``code`` - - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback`` - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string @@ -681,7 +696,7 @@ class OidcProvider: request.addCookie( SESSION_COOKIE_NAME, cookie, - path="/_synapse/oidc", + path="/_synapse/client/oidc", max_age="3600", httpOnly=True, sameSite="lax", @@ -702,7 +717,7 @@ class OidcProvider: async def handle_oidc_callback( self, request: SynapseRequest, session_data: "OidcSessionData", code: str ) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback By this time we have already validated the session on the synapse side, and now need to do the provider-specific operations. This includes: @@ -723,19 +738,18 @@ class OidcProvider: """ # Exchange the code with the provider try: - logger.debug("Exchanging code") + logger.debug("Exchanging OAuth2 code for a token") token = await self._exchange_code(code) except OidcError as e: - logger.exception("Could not exchange code") + logger.exception("Could not exchange OAuth2 code") self._sso_handler.render_error(request, e.error, e.error_description) return - logger.debug("Successfully obtained OAuth2 access token") + logger.debug("Successfully obtained OAuth2 token data: %r", token) # Now that we have a token, get the userinfo, either by decoding the # `id_token` or by fetching the `userinfo_endpoint`. if self._uses_userinfo: - logger.debug("Fetching userinfo") try: userinfo = await self._fetch_userinfo(token) except Exception as e: @@ -743,7 +757,6 @@ class OidcProvider: self._sso_handler.render_error(request, "fetch_error", str(e)) return else: - logger.debug("Extracting userinfo from id_token") try: userinfo = await self._parse_id_token(token, nonce=session_data.nonce) except Exception as e: @@ -1056,7 +1069,8 @@ class OidcSessionData: UserAttributeDict = TypedDict( - "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} + "UserAttributeDict", + {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, ) C = TypeVar("C") @@ -1135,11 +1149,12 @@ def jinja_finalize(thing): env = Environment(finalize=jinja_finalize) -@attr.s +@attr.s(slots=True, frozen=True) class JinjaOidcMappingConfig: subject_claim = attr.ib(type=str) localpart_template = attr.ib(type=Optional[Template]) display_name_template = attr.ib(type=Optional[Template]) + email_template = attr.ib(type=Optional[Template]) extra_attributes = attr.ib(type=Dict[str, Template]) @@ -1156,23 +1171,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - localpart_template = None # type: Optional[Template] - if "localpart_template" in config: + def parse_template_config(option_name: str) -> Optional[Template]: + if option_name not in config: + return None try: - localpart_template = env.from_string(config["localpart_template"]) + return env.from_string(config[option_name]) except Exception as e: - raise ConfigError( - "invalid jinja template", path=["localpart_template"] - ) from e + raise ConfigError("invalid jinja template", path=[option_name]) from e - display_name_template = None # type: Optional[Template] - if "display_name_template" in config: - try: - display_name_template = env.from_string(config["display_name_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template", path=["display_name_template"] - ) from e + localpart_template = parse_template_config("localpart_template") + display_name_template = parse_template_config("display_name_template") + email_template = parse_template_config("email_template") extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: @@ -1192,6 +1201,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + email_template=email_template, extra_attributes=extra_attributes, ) @@ -1213,16 +1223,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): # a usable mxid. localpart += str(failures) if failures else "" - display_name = None # type: Optional[str] - if self._config.display_name_template is not None: - display_name = self._config.display_name_template.render( - user=userinfo - ).strip() + def render_template_field(template: Optional[Template]) -> Optional[str]: + if template is None: + return None + return template.render(user=userinfo).strip() + + display_name = render_template_field(self._config.display_name_template) + if display_name == "": + display_name = None - if display_name == "": - display_name = None + emails = [] # type: List[str] + email = render_template_field(self._config.email_template) + if email: + emails.append(email) - return UserAttributeDict(localpart=localpart, display_name=display_name) + return UserAttributeDict( + localpart=localpart, display_name=display_name, emails=emails + ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a2cf0f6f3e..49b085269b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,8 +14,9 @@ # limitations under the License. """Contains functions for registering clients.""" + import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler): user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: List[str] = [], + bind_emails: Iterable[str] = [], by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, ) -> str: @@ -693,6 +694,8 @@ class RegistrationHandler(BaseHandler): access_token: The access token of the newly logged in device, or None if `inhibit_login` enabled. """ + # TODO: 3pid registration can actually happen on the workers. Consider + # refactoring it. if self.hs.config.worker_app: await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 99fd063084..1336a23a3a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -127,6 +127,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -663,6 +667,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e001e418f9..a5da97cfe0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -144,6 +155,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + """Ratelimit invites by room and by target user. + + If room ID is missing then we just rate limit by target user. + """ + if room_id: + self._invites_per_room_limiter.ratelimit(room_id) + + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -387,8 +408,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + # Don't ratelimit application services. + if not requester.app_service or requester.app_service.is_rate_limited(): + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -412,7 +439,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): block_invite = True if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), target_id, room_id ): logger.info("Blocking invite due to spam checker") block_invite = True diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 38461cf79d..e88fd59749 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,9 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" - # we do not currently support icons for SAML auth, but this is required by + # we do not currently support icons/brands for SAML auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] @@ -132,7 +133,7 @@ class SamlHandler(BaseHandler): raise Exception("prepare_for_authenticate didn't return a Location header") async def handle_saml_response(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_matrix/saml2/authn_response + """Handle an incoming request to /_synapse/client/saml2/authn_response Args: request: the incoming request from the browser. We'll diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index afc1341d09..96ccd991ed 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,21 +14,31 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Set, +) from urllib.parse import urlencode import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string @@ -80,6 +90,11 @@ class SsoIdentityProvider(Protocol): """Optional MXC URI for user-facing icon""" return None + @property + def idp_brand(self) -> Optional[str]: + """Optional branding identifier""" + return None + @abc.abstractmethod async def handle_redirect_request( self, @@ -109,7 +124,7 @@ class UserAttributes: # enter one. localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) - emails = attr.ib(type=List[str], default=attr.Factory(list)) + emails = attr.ib(type=Collection[str], default=attr.Factory(list)) @attr.s(slots=True) @@ -124,7 +139,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name = attr.ib(type=Optional[str]) - emails = attr.ib(type=List[str]) + emails = attr.ib(type=Collection[str]) # An optional dictionary of extra attributes to be provided to the client in the # login response. @@ -136,6 +151,12 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + use_display_name = attr.ib(type=bool, default=True) + emails_to_use = attr.ib(type=Collection[str], default=()) + terms_accepted_version = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -170,6 +191,8 @@ class SsoHandler: # map from idp_id to SsoIdentityProvider self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._consent_at_registration = hs.config.consent.user_consent_at_registration + def register_identity_provider(self, p: SsoIdentityProvider): p_id = p.idp_id assert p_id not in self._identity_providers @@ -382,6 +405,8 @@ class SsoHandler: to an additional page. (e.g. to prompt for more information) """ + new_user = False + # grab a lock while we try to find a mapping for this user. This seems... # optimistic, especially for implementations that end up redirecting to # interstitial pages. @@ -422,9 +447,14 @@ class SsoHandler: get_request_user_agent(request), request.getClientIP(), ) + new_user = True await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_login_attributes + user_id, + request, + client_redirect_url, + extra_login_attributes, + new_user=new_user, ) async def _call_attribute_mapper( @@ -514,7 +544,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) @@ -642,6 +672,25 @@ class SsoHandler: ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def check_username_availability( self, localpart: str, session_id: str, ) -> bool: @@ -658,12 +707,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -680,7 +724,12 @@ class SsoHandler: return not user_infos async def handle_submit_username_request( - self, request: SynapseRequest, localpart: str, session_id: str + self, + request: SynapseRequest, + session_id: str, + localpart: str, + use_display_name: bool, + emails_to_use: Iterable[str], ) -> None: """Handle a request to the username-picker 'submit' endpoint @@ -690,21 +739,103 @@ class SsoHandler: request: HTTP request localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie + use_display_name: whether the user wants to use the suggested display name + emails_to_use: emails that the user would like to use """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + + # update the session with the user's choices + session.chosen_localpart = localpart + session.use_display_name = use_display_name + + emails_from_idp = set(session.emails) + filtered_emails = set() # type: Set[str] + + # we iterate through the list rather than just building a set conjunction, so + # that we can log attempts to use unknown addresses + for email in emails_to_use: + if email in emails_from_idp: + filtered_emails.add(email) + else: + logger.warning( + "[session %s] ignoring user request to use unknown email address %r", + session_id, + email, + ) + session.emails_to_use = filtered_emails - logger.info("[session %s] Registering localpart %s", session_id, localpart) + # we may now need to collect consent from the user, in which case, redirect + # to the consent-extraction-unit + if self._consent_at_registration: + redirect_url = b"/_synapse/client/new_user_consent" + + # otherwise, redirect to the completion page + else: + redirect_url = b"/_synapse/client/sso_register" + + respond_with_redirect(request, redirect_url) + + async def handle_terms_accepted( + self, request: Request, session_id: str, terms_version: str + ): + """Handle a request to the new-user 'consent' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + terms_version: the version of the terms which the user viewed and consented + to + """ + logger.info( + "[session %s] User consented to terms version %s", + session_id, + terms_version, + ) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + + session.terms_accepted_version = terms_version + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. + + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, - display_name=session.display_name, - emails=session.emails, + localpart=session.chosen_localpart, emails=session.emails_to_use, ) + if session.use_display_name: + attributes.display_name = session.display_name + # the following will raise a 400 error if the username has been taken in the # meantime. user_id = await self._register_mapped_user( @@ -715,7 +846,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -728,11 +864,21 @@ class SsoHandler: path=b"/", ) + auth_result = {} + if session.terms_accepted_version: + # TODO: make this less awful. + auth_result[LoginType.TERMS] = True + + await self._registration_handler.post_registration_actions( + user_id, auth_result, access_token=None + ) + await self._auth_handler.complete_sso_login( user_id, request, session.client_redirect_url, session.extra_login_attributes, + new_user=True, ) def _expire_old_sessions(self): @@ -746,3 +892,14 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") |