From f9f03426de338ae1879e174f63adf698bbfc3a4b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 19 Aug 2022 17:17:10 +0100 Subject: Implement MSC3852: Expose `last_seen_user_agent` to users for their own devices; also expose to Admin API (#13549) --- synapse/rest/client/devices.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'synapse/rest/client/devices.py') diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 6fab102437..ed6ce78d47 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -42,12 +42,26 @@ class DevicesRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self._msc3852_enabled = hs.config.experimental.msc3852_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) + + # If MSC3852 is disabled, then the "last_seen_user_agent" field will be + # removed from each device. If it is enabled, then the field name will + # be replaced by the unstable identifier. + # + # When MSC3852 is accepted, this block of code can just be removed to + # expose "last_seen_user_agent" to clients. + for device in devices: + last_seen_user_agent = device["last_seen_user_agent"] + del device["last_seen_user_agent"] + if self._msc3852_enabled: + device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent + return 200, {"devices": devices} @@ -108,6 +122,7 @@ class DeviceRestServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + self._msc3852_enabled = hs.config.experimental.msc3852_enabled async def on_GET( self, request: SynapseRequest, device_id: str @@ -118,6 +133,18 @@ class DeviceRestServlet(RestServlet): ) if device is None: raise NotFoundError("No device found") + + # If MSC3852 is disabled, then the "last_seen_user_agent" field will be + # removed from each device. If it is enabled, then the field name will + # be replaced by the unstable identifier. + # + # When MSC3852 is accepted, this block of code can just be removed to + # expose "last_seen_user_agent" to clients. + last_seen_user_agent = device["last_seen_user_agent"] + del device["last_seen_user_agent"] + if self._msc3852_enabled: + device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent + return 200, device @interactive_auth_handler -- cgit 1.5.1 From 2295095c97f3b4707f30ae8cb4562ebb799f7ac1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 7 Oct 2022 13:54:07 +0100 Subject: Use Pydantic to validate /devices endpoints (#14054) --- changelog.d/14054.feature | 1 + synapse/rest/client/devices.py | 98 ++++++++++++++++++++++-------------------- 2 files changed, 53 insertions(+), 46 deletions(-) create mode 100644 changelog.d/14054.feature (limited to 'synapse/rest/client/devices.py') diff --git a/changelog.d/14054.feature b/changelog.d/14054.feature new file mode 100644 index 0000000000..9cf3f7a557 --- /dev/null +++ b/changelog.d/14054.feature @@ -0,0 +1 @@ +Improve validation of request bodies for the [Device Management](https://spec.matrix.org/v1.4/client-server-api/#device-management) and [MSC2697 Device Dehyrdation](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) client-server API endpoints. diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index ed6ce78d47..90828c95c4 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,18 +14,21 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple + +from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, - assert_params_in_dict, - parse_json_object_from_request, + parse_and_validate_json_object_from_request, ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.rest.client.models import AuthenticationData +from synapse.rest.models import RequestBodyModel from synapse.types import JsonDict if TYPE_CHECKING: @@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() + class PostBody(RequestBodyModel): + auth: Optional[AuthenticationData] + devices: List[StrictStr] + @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.PostBody) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: - # DELETE + # TODO: Can/should we remove this fallback now? # deal with older clients which didn't pass a JSON dict # the same as those that pass an empty dict - body = {} + body = self.PostBody.parse_obj({}) else: raise e - assert_params_in_dict(body, ["devices"]) - await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(exclude_unset=True), "remove device(s) from your account", # Users might call this multiple times in a row while cleaning up # devices, allow a single UI auth session to be re-used. @@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet): ) await self.device_handler.delete_devices( - requester.user.to_string(), body["devices"] + requester.user.to_string(), body.devices ) return 200, {} @@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet): return 200, device + class DeleteBody(RequestBodyModel): + auth: Optional[AuthenticationData] + @interactive_auth_handler async def on_DELETE( self, request: SynapseRequest, device_id: str @@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) try: - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.DeleteBody) except errors.SynapseError as e: if e.errcode == errors.Codes.NOT_JSON: + # TODO: can/should we remove this fallback now? # deal with older clients which didn't pass a JSON dict # the same as those that pass an empty dict - body = {} + body = self.DeleteBody.parse_obj({}) else: raise await self.auth_handler.validate_user_via_ui_auth( requester, request, - body, + body.dict(exclude_unset=True), "remove a device from your account", # Users might call this multiple times in a row while cleaning up # devices, allow a single UI auth session to be re-used. @@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet): ) return 200, {} + class PutBody(RequestBodyModel): + display_name: Optional[StrictStr] + async def on_PUT( self, request: SynapseRequest, device_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - body = parse_json_object_from_request(request) + body = parse_and_validate_json_object_from_request(request, self.PutBody) await self.device_handler.update_device( - requester.user.to_string(), device_id, body + requester.user.to_string(), device_id, body.dict() ) return 200, {} +class DehydratedDeviceDataModel(RequestBodyModel): + """JSON blob describing a dehydrated device to be stored. + + Expects other freeform fields. Use .dict() to access them. + """ + + class Config: + extra = Extra.allow + + algorithm: StrictStr + + class DehydratedDeviceServlet(RestServlet): """Retrieve or store a dehydrated device. @@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet): else: raise errors.NotFoundError("No dehydrated device available") + class PutBody(RequestBodyModel): + device_id: StrictStr + device_data: DehydratedDeviceDataModel + initial_device_display_name: Optional[StrictStr] + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - submission = parse_json_object_from_request(request) + submission = parse_and_validate_json_object_from_request(request, self.PutBody) requester = await self.auth.get_user_by_req(request) - if "device_data" not in submission: - raise errors.SynapseError( - 400, - "device_data missing", - errcode=errors.Codes.MISSING_PARAM, - ) - elif not isinstance(submission["device_data"], dict): - raise errors.SynapseError( - 400, - "device_data must be an object", - errcode=errors.Codes.INVALID_PARAM, - ) - device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission["device_data"], - submission.get("initial_device_display_name", None), + submission.device_data, + submission.initial_device_display_name, ) return 200, {"device_id": device_id} @@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + class PostBody(RequestBodyModel): + device_id: StrictStr + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - submission = parse_json_object_from_request(request) - - if "device_id" not in submission: - raise errors.SynapseError( - 400, - "device_id missing", - errcode=errors.Codes.MISSING_PARAM, - ) - elif not isinstance(submission["device_id"], str): - raise errors.SynapseError( - 400, - "device_id must be a string", - errcode=errors.Codes.INVALID_PARAM, - ) + submission = parse_and_validate_json_object_from_request(request, self.PostBody) result = await self.device_handler.rehydrate_device( requester.user.to_string(), self.auth.get_access_token_from_request(request), - submission["device_id"], + submission.device_id, ) return 200, result -- cgit 1.5.1 From dbfc9b803ee32f7b31c2b5ccbc53a1bfcaa95983 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 31 Oct 2022 20:31:43 +0000 Subject: Fix dehydrated device REST checks (#14336) --- changelog.d/14336.bugfix | 1 + synapse/rest/client/devices.py | 5 ++--- tests/rest/client/test_devices.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 changelog.d/14336.bugfix (limited to 'synapse/rest/client/devices.py') diff --git a/changelog.d/14336.bugfix b/changelog.d/14336.bugfix new file mode 100644 index 0000000000..d44ff1bbc7 --- /dev/null +++ b/changelog.d/14336.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70 where clients were unable to PUT new [dehydrated devices](https://github.com/matrix-org/matrix-spec-proposals/pull/2697). diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 90828c95c4..8f3cbd4ea2 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -231,7 +231,7 @@ class DehydratedDeviceServlet(RestServlet): } } - PUT /org.matrix.msc2697/dehydrated_device + PUT /org.matrix.msc2697.v2/dehydrated_device Content-Type: application/json { @@ -271,7 +271,6 @@ class DehydratedDeviceServlet(RestServlet): raise errors.NotFoundError("No dehydrated device available") class PutBody(RequestBodyModel): - device_id: StrictStr device_data: DehydratedDeviceDataModel initial_device_display_name: Optional[StrictStr] @@ -281,7 +280,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), - submission.device_data, + submission.device_data.dict(), submission.initial_device_display_name, ) return 200, {"device_id": device_id} diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index aa98222434..d80eea17d3 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -200,3 +200,37 @@ class DevicesTestCase(unittest.HomeserverTestCase): self.reactor.advance(43200) self.get_success(self.handler.get_device(user_id, "abc")) self.get_failure(self.handler.get_device(user_id, "def"), NotFoundError) + + +class DehydratedDeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + + def test_PUT(self) -> None: + """Sanity-check that we can PUT a dehydrated device. + + Detects https://github.com/matrix-org/synapse/issues/14334. + """ + alice = self.register_user("alice", "correcthorse") + token = self.login(alice, "correcthorse") + + # Have alice update their device list + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device", + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device", + } + }, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + device_id = channel.json_body.get("device_id") + self.assertIsInstance(device_id, str) -- cgit 1.5.1 From 6d47b7e32589e816eb766446cc1ff19ea73fc7c1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 14:08:04 -0500 Subject: Add a type hint for `get_device_handler()` and fix incorrect types. (#14055) This was the last untyped handler from the HomeServer object. Since it was being treated as Any (and thus unchecked) it was being used incorrectly in a few places. --- changelog.d/14055.misc | 1 + synapse/handlers/deactivate_account.py | 4 +++ synapse/handlers/device.py | 65 ++++++++++++++++++++++++++-------- synapse/handlers/e2e_keys.py | 61 ++++++++++++++++--------------- synapse/handlers/register.py | 4 +++ synapse/handlers/set_password.py | 6 +++- synapse/handlers/sso.py | 9 +++++ synapse/module_api/__init__.py | 10 +++++- synapse/replication/http/devices.py | 11 ++++-- synapse/rest/admin/__init__.py | 26 ++++++++------ synapse/rest/admin/devices.py | 13 +++++-- synapse/rest/client/devices.py | 17 ++++++--- synapse/rest/client/logout.py | 9 +++-- synapse/server.py | 2 +- tests/handlers/test_device.py | 19 ++++++---- tests/rest/admin/test_device.py | 5 ++- 16 files changed, 185 insertions(+), 77 deletions(-) create mode 100644 changelog.d/14055.misc (limited to 'synapse/rest/client/devices.py') diff --git a/changelog.d/14055.misc b/changelog.d/14055.misc new file mode 100644 index 0000000000..02980bc528 --- /dev/null +++ b/changelog.d/14055.misc @@ -0,0 +1 @@ +Add missing type hints to `HomeServer`. diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 816e1a6d79..d74d135c0c 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError +from synapse.handlers.device import DeviceHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Codes, Requester, UserID, create_requester @@ -76,6 +77,9 @@ class DeactivateAccountHandler: True if identity server supports removing threepids, otherwise False. """ + # This can only be called on the main process. + assert isinstance(self._device_handler, DeviceHandler) + # Check if this user can be deactivated if not await self._third_party_rules.check_can_deactivate_user( user_id, by_admin diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index da3ddafeae..b1e55e1b9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: + device_list_updater: "DeviceListWorkerUpdater" + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.hs = hs @@ -76,6 +78,8 @@ class DeviceWorkerHandler: self.server_name = hs.hostname self._msc3852_enabled = hs.config.experimental.msc3852_enabled + self.device_list_updater = DeviceListWorkerUpdater(hs) + @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ @@ -99,6 +103,19 @@ class DeviceWorkerHandler: log_kv(device_map) return devices + async def get_dehydrated_device( + self, user_id: str + ) -> Optional[Tuple[str, JsonDict]]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: """Retrieve the given device @@ -127,7 +144,7 @@ class DeviceWorkerHandler: @cancellable async def get_device_changes_in_shared_rooms( self, user_id: str, room_ids: Collection[str], from_token: StreamToken - ) -> Collection[str]: + ) -> Set[str]: """Get the set of users whose devices have changed who share a room with the given user. """ @@ -320,6 +337,8 @@ class DeviceWorkerHandler: class DeviceHandler(DeviceWorkerHandler): + device_list_updater: "DeviceListUpdater" + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, [old_device_id]) return device_id - async def get_dehydrated_device( - self, user_id: str - ) -> Optional[Tuple[str, JsonDict]]: - """Retrieve the information for a dehydrated device. - - Args: - user_id: the user whose dehydrated device we are looking for - Returns: - a tuple whose first item is the device ID, and the second item is - the dehydrated device information - """ - return await self.store.get_dehydrated_device(user_id) - async def rehydrate_device( self, user_id: str, access_token: str, device_id: str ) -> dict: @@ -882,7 +888,36 @@ def _update_device_from_client_ips( ) -class DeviceListUpdater: +class DeviceListWorkerUpdater: + "Handles incoming device list updates from federation and contacts the main process over replication" + + def __init__(self, hs: "HomeServer"): + from synapse.replication.http.devices import ( + ReplicationUserDevicesResyncRestServlet, + ) + + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) + ) + + async def user_device_resync( + self, user_id: str, mark_failed_as_stale: bool = True + ) -> Optional[JsonDict]: + """Fetches all devices for a user and updates the device cache with them. + + Args: + user_id: The user's id whose device_list will be updated. + mark_failed_as_stale: Whether to mark the user's device list as stale + if the attempt to resync failed. + Returns: + A dict with device info as under the "devices" in the result of this + request: + https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid + """ + return await self._user_device_resync_client(user_id=user_id) + + +class DeviceListUpdater(DeviceListWorkerUpdater): "Handles incoming device list updates from federation and updates the DB" def __init__(self, hs: "HomeServer", device_handler: DeviceHandler): diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index bf1221f523..5fe102e2f2 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -27,9 +27,9 @@ from twisted.internet import defer from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( JsonDict, UserID, @@ -56,27 +56,23 @@ class E2eKeysHandler: self.is_mine = hs.is_mine self.clock = hs.get_clock() - self._edu_updater = SigningKeyEduUpdater(hs, self) - federation_registry = hs.get_federation_registry() - self._is_master = hs.config.worker.worker_app is None - if not self._is_master: - self._user_device_resync_client = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) - else: + is_master = hs.config.worker.worker_app is None + if is_master: + edu_updater = SigningKeyEduUpdater(hs) + # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( EduTypes.SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, - self._edu_updater.incoming_signing_key_update, + edu_updater.incoming_signing_key_update, ) # doesn't really work as part of the generic query API, because the @@ -319,14 +315,13 @@ class E2eKeysHandler: # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - if self._is_master: - resync_results = await self.device_handler.device_list_updater.user_device_resync( + resync_results = ( + await self.device_handler.device_list_updater.user_device_resync( user_id ) - else: - resync_results = await self._user_device_resync_client( - user_id=user_id - ) + ) + if resync_results is None: + raise ValueError("Device resync failed") # Add the device keys to the results. user_devices = resync_results["devices"] @@ -605,6 +600,8 @@ class E2eKeysHandler: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) time_now = self.clock.time_msec() @@ -732,6 +729,8 @@ class E2eKeysHandler: user_id: the user uploading the keys keys: the signing keys """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) # if a master key is uploaded, then check it. Otherwise, load the # stored master key, to check signatures on other keys @@ -823,6 +822,9 @@ class E2eKeysHandler: Raises: SynapseError: if the signatures dict is not valid. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + failures = {} # signatures to be stored. Each item will be a SignatureListItem @@ -1200,6 +1202,9 @@ class E2eKeysHandler: A tuple of the retrieved key content, the key's ID and the matching VerifyKey. If the key cannot be retrieved, all values in the tuple will instead be None. """ + # This can only be called from the main process. + assert isinstance(self.device_handler, DeviceHandler) + try: remote_result = await self.federation.query_user_devices( user.domain, user.to_string() @@ -1396,11 +1401,14 @@ class SignatureListItem: class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.clock = hs.get_clock() - self.e2e_keys_handler = e2e_keys_handler + + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler self._remote_edu_linearizer = Linearizer(name="remote_signing_key") @@ -1445,9 +1453,6 @@ class SigningKeyEduUpdater: user_id: the user whose updates we are processing """ - device_handler = self.e2e_keys_handler.device_handler - device_list_updater = device_handler.device_list_updater - async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: @@ -1459,13 +1464,11 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = ( - await device_list_updater.process_cross_signing_key_update( - user_id, - master_key, - self_signing_key, - ) + new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + new_device_ids - await device_handler.notify_device_update(user_id, device_ids) + await self._device_handler.notify_device_update(user_id, device_ids) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1c7a1866..6307fa9c5d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( ) from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import assert_params_in_dict from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( @@ -841,6 +842,9 @@ class RegistrationHandler: refresh_token = None refresh_token_id = None + # This can only run on the main process. + assert isinstance(self.device_handler, DeviceHandler) + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 73861bbd40..bd9d0bb34b 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.types import Requester if TYPE_CHECKING: @@ -29,7 +30,10 @@ class SetPasswordHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + # This can only be instantiated on the main process. + device_handler = hs.get_device_handler() + assert isinstance(device_handler, DeviceHandler) + self._device_handler = device_handler async def set_password( self, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 749d7e93b0..e1c0bff1b2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.device import DeviceHandler from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent @@ -1035,6 +1036,8 @@ class SsoHandler: ) -> None: """Revoke any devices and in-flight logins tied to a provider session. + Can only be called from the main process. + Args: auth_provider_id: A unique identifier for this SSO provider, e.g. "oidc" or "saml". @@ -1042,6 +1045,12 @@ class SsoHandler: expected_user_id: The user we're expecting to logout. If set, it will ignore sessions belonging to other users and log an error. """ + + # It is expected that this is the main process. + assert isinstance( + self._device_handler, DeviceHandler + ), "revoking SSO sessions can only be called on the main process" + # Invalidate any running user-mapping sessions to_delete = [] for session_id, session in self._username_mapping_sessions.items(): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1adc1fd64f..96a661177a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,6 +86,7 @@ from synapse.handlers.auth import ( ON_LOGGED_OUT_CALLBACK, AuthHandler, ) +from synapse.handlers.device import DeviceHandler from synapse.handlers.push_rules import RuleSpec, check_actions from synapse.http.client import SimpleHttpClient from synapse.http.server import ( @@ -207,6 +208,7 @@ class ModuleApi: self._registration_handler = hs.get_registration_handler() self._send_email_handler = hs.get_send_email_handler() self._push_rules_handler = hs.get_push_rules_handler() + self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory try: @@ -784,6 +786,8 @@ class ModuleApi: ) -> Generator["defer.Deferred[Any]", Any, None]: """Invalidate an access token for a user + Can only be called from the main process. + Added in Synapse v0.25.0. Args: @@ -796,6 +800,10 @@ class ModuleApi: Raises: synapse.api.errors.AuthError: the access token is invalid """ + assert isinstance( + self._device_handler, DeviceHandler + ), "invalidate_access_token can only be called on the main process" + # see if the access token corresponds to a device user_info = yield defer.ensureDeferred( self._auth.get_user_by_access_token(access_token) @@ -805,7 +813,7 @@ class ModuleApi: if device_id: # delete the device, which will also delete its access tokens yield defer.ensureDeferred( - self._hs.get_device_handler().delete_devices(user_id, [device_id]) + self._device_handler.delete_devices(user_id, [device_id]) ) else: # no associated device. Just delete the access token. diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index c21629def8..7c4941c3d3 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -63,7 +63,12 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.device_list_updater = hs.get_device_handler().device_list_updater + from synapse.handlers.device import DeviceHandler + + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_list_updater = handler.device_list_updater + self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,7 +78,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): async def _handle_request( # type: ignore[override] self, request: Request, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, Optional[JsonDict]]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index c62ea22116..fb73886df0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -238,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: """ Register all the admin servlets. """ + # Admin servlets aren't registered on workers. + if hs.config.worker.worker_app is not None: + return + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -254,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeleteDevicesRestServlet(hs).register(http_server) UserMediaStatisticsRestServlet(hs).register(http_server) EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) @@ -280,12 +281,13 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserByExternalId(hs).register(http_server) UserByThreePid(hs).register(http_server) - # Some servlets only get registered for the main process. - if hs.config.worker.worker_app is None: - SendServerNoticeServlet(hs).register(http_server) - BackgroundUpdateEnabledRestServlet(hs).register(http_server) - BackgroundUpdateRestServlet(hs).register(http_server) - BackgroundUpdateStartJobRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeleteDevicesRestServlet(hs).register(http_server) + SendServerNoticeServlet(hs).register(http_server) + BackgroundUpdateEnabledRestServlet(hs).register(http_server) + BackgroundUpdateRestServlet(hs).register(http_server) + BackgroundUpdateStartJobRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( @@ -294,9 +296,11 @@ def register_servlets_for_client_rest_resource( """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) - ResetPasswordRestServlet(hs).register(http_server) + # The following resources can only be run on the main process. + if hs.config.worker.worker_app is None: + DeactivateAccountRestServlet(hs).register(http_server) + ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index d934880102..3b2f2d9abb 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.errors import NotFoundError, SynapseError +from synapse.handlers.device import DeviceHandler from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine @@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.store = hs.get_datastores().main self.is_mine = hs.is_mine diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8f3cbd4ea2..69b803f9f8 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -20,6 +20,7 @@ from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -80,7 +81,9 @@ class DeleteDevicesRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() class PostBody(RequestBodyModel): @@ -125,7 +128,9 @@ class DeviceRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled @@ -256,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -313,7 +320,9 @@ class ClaimDehydratedDeviceServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.device_handler = handler class PostBody(RequestBodyModel): device_id: StrictStr diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 23dfa4518f..6d34625ad5 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self._device_handler = handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) diff --git a/synapse/server.py b/synapse/server.py index f0a60d0056..5baae2325e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -510,7 +510,7 @@ class HomeServer(metaclass=abc.ABCMeta): ) @cache_in_self - def get_device_handler(self): + def get_device_handler(self) -> DeviceWorkerHandler: if self.config.worker.worker_app: return DeviceWorkerHandler(self) else: diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index b8b465d35b..ce7525e29c 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,7 +19,7 @@ from typing import Optional from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError -from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN +from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.util import Clock @@ -32,7 +32,9 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.store = hs.get_datastores().main return hs @@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_is_preserved_if_exists(self) -> None: @@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(res2, "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) + assert dev is not None self.assertEqual(dev["display_name"], "display name") def test_device_id_is_made_up_if_unspecified(self) -> None: @@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) + assert dev is not None self.assertEqual(dev["display_name"], "display") def test_get_devices_by_user(self) -> None: @@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver("server", federation_http_client=None) - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.registration = hs.get_registration_handler() self.auth = hs.get_auth() self.store = hs.get_datastores().main @@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ) ) - retrieved_device_id, device_data = self.get_success( - self.handler.get_dehydrated_device(user_id=user_id) - ) + result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) + assert result is not None + retrieved_device_id, device_data = result self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index d52aee8f92..03f2112b07 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -19,6 +19,7 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes +from synapse.handlers.device import DeviceHandler from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -34,7 +35,9 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.handler = hs.get_device_handler() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.handler = handler self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") -- cgit 1.5.1