diff options
-rw-r--r-- | synapse/handlers/device.py | 78 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 55 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 31 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/devices.py | 132 | ||||
-rw-r--r-- | synapse/storage/data_stores/main/schema/delta/58/11dehydration.sql | 30 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 69 |
6 files changed, 375 insertions, 20 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60de..614b10ca9e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,8 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -28,6 +29,7 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( + JsonDict, RoomStreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -489,6 +491,80 @@ class DeviceHandler(DeviceWorkerHandler): # receive device updates. Mark this in DB. await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + async def store_dehydrated_device( + self, user_id: str, device_data: str, + initial_device_display_name: Optional[str] = None) -> str: + device_id = await self.check_device_registered( + user_id, None, initial_device_display_name, + ) + old_device_id = await self.store.store_dehydrated_device(user_id, device_id, device_data) + if old_device_id is not None: + await self.delete_device(user_id, old_device_id) + return device_id + + async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]: + return await self.store.get_dehydrated_device(user_id) + + async def get_dehydration_token(self, user_id: str, device_id: str, login_submission: JsonDict) -> str: + return await self.store.create_dehydration_token(user_id, device_id, json.dumps(login_submission)) + + async def rehydrate_device(self, token: str) -> dict: + # FIXME: if can't find token, return 404 + token_info = await self.store.clear_dehydration_token(token, True) + + # normally, the constructor would do self.registration_handler = + # self.hs.get_registration_handler(), but doing that results in a + # circular dependency in the handlers. So do this for now + registration_handler = self.hs.get_registration_handler() + + if token_info["dehydrated"]: + # create access token for dehydrated device + initial_display_name = None # FIXME: get display name from login submission? + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), token_info.get("device_id"), initial_display_name + ) + + return { + "user_id": token_info.get("user_id"), + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + + else: + # create device and access token from original login submission + login_submission = token_info.get("login_submission") + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), device_id, initial_display_name + ) + + return { + "user_id": token.info.get("user_id"), + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + + async def cancel_rehydrate(self, token: str) -> dict: + # FIXME: if can't find token, return 404 + token_info = await self.store.clear_dehydration_token(token) + # create device and access token from original login submission + login_submission = token_info.get("login_submission") + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = await self.registration_handler.register_device( + token_info.get("user_id"), device_id, initial_display_name + ) + + return { + "user_id": token_info.get("user_id"), + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 379f668d6f..3e6da34de9 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -103,6 +103,7 @@ class LoginRestServlet(RestServlet): self.oidc_enabled = hs.config.oidc_enabled self.auth_handler = self.hs.get_auth_handler() + self.device_handler = hs.get_device_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) @@ -339,6 +340,22 @@ class LoginRestServlet(RestServlet): ) user_id = canonical_uid + if login_submission.get("org.matrix.msc2697.restore_device"): + device_id, dehydrated_device = await self.device_handler.get_dehydrated_device(user_id) + if dehydrated_device: + token = await self.device_handler.get_dehydration_token(user_id, device_id, login_submission) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_data": dehydrated_device, + "device_id": device_id, + "dehydration_token": token, + } + + # FIXME: call callback? + + return result + device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( @@ -401,6 +418,42 @@ class LoginRestServlet(RestServlet): return result +class RestoreDeviceServlet(RestServlet): + PATTERNS = client_patterns("/org.matrix.msc26997/restore_device") + + def __init__(self, hs): + super(RestoreDeviceServlet, self).__init__() + self.hs = hs + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + + if submission.get("rehydrate"): + return 200, await self.device_handler.rehydrate_device(submission.get("dehydration_token")) + else: + return 200, await self.device_handler.cancel_rehydrate(submission.get("dehydration_token")) + + +class StoreDeviceServlet(RestServlet): + PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate") + + def __init__(self, hs): + super(StoreDeviceServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), submission.get("device_data") + ) + return 200, {"device_id": device_id} + + class BaseSSORedirectServlet(RestServlet): """Common base class for /login/sso/redirect impls""" @@ -499,6 +552,8 @@ class OIDCRedirectServlet(BaseSSORedirectServlet): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + RestoreDeviceServlet(hs).register(http_server) + StoreDeviceServlet(hs).register(http_server) if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24bb090822..fa18db1946 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -67,6 +67,7 @@ class KeyUploadServlet(RestServlet): super(KeyUploadServlet, self).__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") async def on_POST(self, request, device_id): @@ -78,20 +79,22 @@ class KeyUploadServlet(RestServlet): # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. if requester.device_id is not None and device_id != requester.device_id: - set_tag("error", True) - log_kv( - { - "message": "Client uploading keys for a different device", - "logged_in_id": requester.device_id, - "key_being_uploaded": device_id, - } - ) - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) + dehydrated_device_id, _ = await self.device_handler.get_dehydrated_device(user_id) + if device_id != dehydrated_device_id: + set_tag("error", True) + log_kv( + { + "message": "Client uploading keys for a different device", + "logged_in_id": requester.device_id, + "key_being_uploaded": device_id, + } + ) + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 45581a6500..c16ed922ae 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -43,7 +43,7 @@ from synapse.util.caches.descriptors import ( cachedList, ) from synapse.util.iterutils import batch_iter -from synapse.util.stringutils import shortstr +from synapse.util.stringutils import random_string, shortstr logger = logging.getLogger(__name__) @@ -728,6 +728,129 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]: + row = await self.db.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return (row["device_id"], row["device_data"]) if row else (None, None) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + if old_device_id is None: + self.db.simple_insert_txn( + txn, + table="dehydrated_devices", + values={ + "user_id": user_id, + "device_id": device_id, + "device_data": device_data, + }, + ) + else: + self.db.simple_update_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id", user_id}, + updatevalues={ + "device_id": device_id, + "device_data": device_data, + }, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + return await self.db.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, device_id, device_data, + ) + + async def create_dehydration_token( + self, user_id: str, device_id: str, login_submission: str + ) -> str: + # FIXME: expire any old tokens + + attempts = 0 + while attempts < 5: + token = random_string(24) + + try: + await self.db.simple_insert( + table="dehydration_token", + values={ + "token": token, + "user_id": user_id, + "device_id": device_id, + "login_submission": login_submission, + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_dehydration_token", + ) + return token + except self.db.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a token.") + + def _clear_dehydration_token_txn(self, txn, token: str, dehydrate: bool) -> dict: + token_info = self.db.simple_select_one_txn( + txn, + "dehydration_token", + { + "token": token, + }, + ["user_id", "device_id", "login_submission"], + ) + self.db.simple_delete_one_txn( + txn, + "dehydration_token", + { + "token": token, + }, + ) + + if dehydrate: + device = self.db.simple_select_one_txn( + txn, + "dehydrated_devices", + {"user_id": token_info["user_id"]}, + ["device_id", "device_data"], + allow_none=True, + ) + if device and device["device_id"] == token_info["device_id"]: + count = self.db.simple_delete_txn( + txn, + "dehydrated_devices", + { + "user_id": token_info["user_id"], + "device_id": token_info["device_id"], + }, + ) + if count != 0: + token_info["dehydrated"] = True + + return token_info + + async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict: + return await self.db.runInteraction( + "get_users_whose_devices_changed", + self._clear_dehydration_token_txn, + token, + dehydrate, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): @@ -865,8 +988,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) - @defer.inlineCallbacks - def store_device(self, user_id, device_id, initial_device_display_name): + async def store_device(self, user_id, device_id, initial_device_display_name): """Ensure the given device is known; add it to the store if not Args: @@ -885,7 +1007,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self.db.simple_insert( + inserted = await self.db.simple_insert( "devices", values={ "user_id": user_id, @@ -899,7 +1021,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self.db.simple_select_one_onecol( + hidden = await self.db.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", diff --git a/synapse/storage/data_stores/main/schema/delta/58/11dehydration.sql b/synapse/storage/data_stores/main/schema/delta/58/11dehydration.sql new file mode 100644 index 0000000000..be5e8a4712 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS dehydration_token( + token TEXT NOT NULL PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + login_submission TEXT NOT NULL, + creation_time BIGINT NOT NULL +); + +-- FIXME: index on creation_time to expire old tokens diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index db52725cfe..cafa581be7 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -754,3 +754,72 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature verification failed", ) + + +class DehydrationTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + logout.register_servlets, + devices.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False + + return self.hs + + def test_dehydrate_and_rehydrate_device(self): + self.register_user("kermit", "monkey") + access_token = self.login("kermit", "monkey") + + # dehydrate a device + params = json.dumps({ + "device_data": "foobar" + }) + request, channel = self.make_request( + b"POST", b"/_matrix/client/unstable/org.matrix.msc2697/device/dehydrate", + params, + access_token=access_token + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + dehydrated_device_id = channel.json_body["device_id"] + + # Log out + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + + # log in, requesting a dehydrated device + params = json.dumps({ + "type": "m.login.password", + "user": "kermit", + "password": "monkey", + "org.matrix.msc2697.restore_device": True, + }) + request, channel = self.make_request( + "POST", "/_matrix/client/r0/login", params + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["device_data"], "foobar") + self.assertEqual(channel.json_body["device_id"], dehydrated_device_id) + dehydration_token = channel.json_body["dehydration_token"] + + params = json.dumps({ + "rehydrate": True, + "dehydration_token": dehydration_token + }) + request, channel = self.make_request( + "POST", "/_matrix/client/unstable/org.matrix.msc2697/restore_device", params + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["device_id"], dehydrated_device_id) |