diff options
-rw-r--r-- | changelog.d/7955.feature | 1 | ||||
-rw-r--r-- | synapse/handlers/device.py | 135 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 16 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 8 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 116 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 34 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sync.py | 1 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 171 | ||||
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 85 | ||||
-rw-r--r-- | synapse/storage/databases/main/schema/delta/58/11dehydration.sql | 30 | ||||
-rw-r--r-- | synapse/storage/databases/main/schema/delta/58/11fallback.sql | 24 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 65 |
12 files changed, 669 insertions, 17 deletions
diff --git a/changelog.d/7955.feature b/changelog.d/7955.feature new file mode 100644 index 0000000000..7d726046fe --- /dev/null +++ b/changelog.d/7955.feature @@ -0,0 +1 @@ +Add support for device dehydration. (MSC2697) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60de..7c809b27f0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,8 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -28,6 +29,7 @@ from synapse.api.errors import ( from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import ( + JsonDict, RoomStreamToken, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -489,6 +491,137 @@ class DeviceHandler(DeviceWorkerHandler): # receive device updates. Mark this in DB. await self.store.mark_remote_user_device_list_as_unsubscribed(user_id) + async def store_dehydrated_device( + self, + user_id: str, + device_data: JsonDict, + initial_device_display_name: Optional[str] = None, + ) -> str: + """Store a dehydrated device for a user. If the user had a previous + dehydrated device, it is removed. + + Args: + user_id: the user that we are storing the device for + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + Returns: + device id of the dehydrated device + """ + device_id = await self.check_device_registered( + user_id, None, initial_device_display_name, + ) + old_device_id = await self.store.store_dehydrated_device( + user_id, device_id, device_data + ) + if old_device_id is not None: + await self.delete_device(user_id, old_device_id) + return device_id + + async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + return await self.store.get_dehydrated_device(user_id) + + async def create_dehydration_token( + self, user_id: str, device_id: str, login_submission: JsonDict + ) -> str: + """Create a token for a client to fulfill a dehydration request. + + Args: + user_id: the user that we are creating the token for + device_id: the device ID for the dehydrated device. This is to + ensure that the device still exists when the user tells us + they want to use the dehydrated device. + login_submission: the contents of the login request. + Returns: + the dehydration token + """ + return await self.store.create_dehydration_token( + user_id, device_id, login_submission + ) + + async def rehydrate_device(self, token: str) -> dict: + """Process a rehydration request from the user. + + Args: + token: the dehydration token + Returns: + the login result, including the user's access token and device ID + """ + # FIXME: if can't find token, return 404 + token_info = await self.store.clear_dehydration_token(token, True) + + # normally, the constructor would do self.registration_handler = + # self.hs.get_registration_handler(), but doing that results in a + # circular dependency in the handlers. So do this for now + registration_handler = self.hs.get_registration_handler() + + if token_info["dehydrated"]: + # create access token for dehydrated device + initial_display_name = ( + None # FIXME: get display name from login submission? + ) + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), + token_info.get("device_id"), + initial_display_name, + ) + + return { + "user_id": token_info["user_id"], + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + + else: + # create device and access token from original login submission + login_submission = token_info["login_submission"] + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), device_id, initial_display_name + ) + + return { + "user_id": token.info["user_id"], + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + + async def cancel_rehydrate(self, token: str) -> dict: + """Cancel a rehydration request from the user and complete the user's login. + + Args: + token: the dehydration token + Returns: + the login result, including the user's access token and device ID + """ + # FIXME: if can't find token, return 404 + token_info = await self.store.clear_dehydration_token(token, False) + # create device and access token from original login submission + login_submission = token_info["login_submission"] + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + registration_handler = self.hs.get_registration_handler() + device_id, access_token = await registration_handler.register_device( + token_info.get("user_id"), device_id, initial_display_name + ) + + return { + "user_id": token_info.get("user_id"), + "access_token": access_token, + "home_server": self.hs.hostname, + "device_id": device_id, + } + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 84169c1022..0c37829afc 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -496,6 +496,22 @@ class E2eKeysHandler(object): log_kv( {"message": "Did not update one_time_keys", "reason": "no keys given"} ) + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys and isinstance(fallback_keys, dict): + log_kv( + { + "message": "Updating fallback_keys for device.", + "user_id": user_id, + "device_id": device_id, + } + ) + await self.store.set_e2e_fallback_keys( + user_id, device_id, fallback_keys + ) + else: + log_kv( + {"message": "Did not update fallback_keys", "reason": "no keys given"} + ) # the device should have been registered already, but it may have been # deleted due to a race with a DELETE request. Or we may be using an diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c42dac18f5..e340b1e615 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -203,6 +203,8 @@ class SyncResult: device_lists: List of user_ids whose devices have changed device_one_time_keys_count: Dict of algorithm to count for one time keys for this device + device_unused_fallback_keys: List of key types that have an unused fallback + key groups: Group updates, if any """ @@ -215,6 +217,7 @@ class SyncResult: to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) device_one_time_keys_count = attr.ib(type=JsonDict) + device_unused_fallback_keys = attr.ib(type=List[str]) groups = attr.ib(type=Optional[GroupsSyncResult]) def __nonzero__(self) -> bool: @@ -1024,10 +1027,14 @@ class SyncHandler(object): logger.debug("Fetching OTK data") device_id = sync_config.device_id one_time_key_counts = {} # type: JsonDict + unused_fallback_keys = [] # type: list if device_id: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) + unused_fallback_keys = await self.store.get_e2e_unused_fallback_keys( + user_id, device_id + ) logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) @@ -1051,6 +1058,7 @@ class SyncHandler(object): device_lists=device_lists, groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, + device_unused_fallback_keys=unused_fallback_keys, next_batch=sync_result_builder.now_token, ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 379f668d6f..68fece986b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -103,6 +103,7 @@ class LoginRestServlet(RestServlet): self.oidc_enabled = hs.config.oidc_enabled self.auth_handler = self.hs.get_auth_handler() + self.device_handler = hs.get_device_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) @@ -339,6 +340,29 @@ class LoginRestServlet(RestServlet): ) user_id = canonical_uid + if login_submission.get("org.matrix.msc2697.restore_device"): + # user requested to rehydrate a device, so check if there they have + # a dehydrated device, and if so, allow them to try to rehydrate it + ( + device_id, + dehydrated_device, + ) = await self.device_handler.get_dehydrated_device(user_id) + if dehydrated_device: + token = await self.device_handler.create_dehydration_token( + user_id, device_id, login_submission + ) + result = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_data": dehydrated_device, + "device_id": device_id, + "dehydration_token": token, + } + + # FIXME: call callback? + + return result + device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( @@ -401,6 +425,96 @@ class LoginRestServlet(RestServlet): return result +class RestoreDeviceServlet(RestServlet): + """Complete a rehydration request, either by letting the client use the + dehydrated device, or by creating a new device for the user. + + POST /org.matrix.msc2697/restore_device + Content-Type: application/json + + { + "rehydrate": true, + "dehydration_token": "an_opaque_token" + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { // same format as the result from a /login request + "user_id": "@alice:example.org", + "device_id": "dehydrated_device", + "access_token": "another_opaque_token" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697/restore_device") + + def __init__(self, hs): + super(RestoreDeviceServlet, self).__init__() + self.hs = hs + self.device_handler = hs.get_device_handler() + self._well_known_builder = WellKnownBuilder(hs) + + async def on_POST(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + + if submission.get("rehydrate"): + result = await self.device_handler.rehydrate_device( + submission["dehydration_token"] + ) + else: + result = await self.device_handler.cancel_rehydrate( + submission["dehydration_token"] + ) + well_known_data = self._well_known_builder.get_well_known() + if well_known_data: + result["well_known"] = well_known_data + return (200, result) + + +class StoreDeviceServlet(RestServlet): + """Store a dehydrated device. + + POST /org.matrix.msc2697/device/dehydrate + Content-Type: application/json + + { + "device_data": { + "algorithm": "m.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate") + + def __init__(self, hs): + super(StoreDeviceServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission["device_data"], + submission.get("initial_device_display_name", None) + ) + return 200, {"device_id": device_id} + + class BaseSSORedirectServlet(RestServlet): """Common base class for /login/sso/redirect impls""" @@ -499,6 +613,8 @@ class OIDCRedirectServlet(BaseSSORedirectServlet): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + RestoreDeviceServlet(hs).register(http_server) + StoreDeviceServlet(hs).register(http_server) if hs.config.cas_enabled: CasRedirectServlet(hs).register(http_server) CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 24bb090822..b86c8f598b 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -67,6 +67,7 @@ class KeyUploadServlet(RestServlet): super(KeyUploadServlet, self).__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") async def on_POST(self, request, device_id): @@ -78,20 +79,25 @@ class KeyUploadServlet(RestServlet): # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. if requester.device_id is not None and device_id != requester.device_id: - set_tag("error", True) - log_kv( - { - "message": "Client uploading keys for a different device", - "logged_in_id": requester.device_id, - "key_being_uploaded": device_id, - } - ) - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) + ( + dehydrated_device_id, + _, + ) = await self.device_handler.get_dehydrated_device(user_id) + if device_id != dehydrated_device_id: + set_tag("error", True) + log_kv( + { + "message": "Client uploading keys for a different device", + "logged_in_id": requester.device_id, + "key_being_uploaded": device_id, + } + ) + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..6f4b224454 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -237,6 +237,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "device_unused_fallback_keys": sync_result.device_unused_fallback_keys, "next_batch": sync_result.next_batch.to_string(), } diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2b33060480..f9385a2c83 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -33,9 +33,14 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) from synapse.util.iterutils import batch_iter -from synapse.util.stringutils import shortstr +from synapse.util.stringutils import random_string, shortstr logger = logging.getLogger(__name__) @@ -746,6 +751,168 @@ class DeviceWorkerStore(SQLBaseStore): _mark_remote_user_device_list_as_unsubscribed_txn, ) + async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]: + """Retrieve the information for a dehydrated device. + + Args: + user_id: the user whose dehydrated device we are looking for + Returns: + a tuple whose first item is the device ID, and the second item is + the dehydrated device information + """ + # FIXME: make sure device ID still exists in devices table + row = await self.db_pool.simple_select_one( + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcols=["device_id", "device_data"], + allow_none=True, + ) + return (row["device_id"], json.loads(row["device_data"])) if row else (None, None) + + def _store_dehydrated_device_txn( + self, txn, user_id: str, device_id: str, device_data: str + ) -> Optional[str]: + old_device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + retcol="device_id", + allow_none=True, + ) + if old_device_id is None: + self.db_pool.simple_insert_txn( + txn, + table="dehydrated_devices", + values={ + "user_id": user_id, + "device_id": device_id, + "device_data": device_data, + }, + ) + else: + self.db_pool.simple_update_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id}, + updatevalues={"device_id": device_id, "device_data": device_data}, + ) + return old_device_id + + async def store_dehydrated_device( + self, user_id: str, device_id: str, device_data: JsonDict + ) -> Optional[str]: + """Store a dehydrated device for a user. + + Args: + user_id: the user that we are storing the device for + device_data: the dehydrated device information + initial_device_display_name: The display name to use for the device + Returns: + device id of the user's previous dehydrated device, if any + """ + return await self.db_pool.runInteraction( + "store_dehydrated_device_txn", + self._store_dehydrated_device_txn, + user_id, + device_id, + json.dumps(device_data), + ) + + async def create_dehydration_token( + self, user_id: str, device_id: str, login_submission: JsonDict + ) -> str: + """Create a token for a client to fulfill a dehydration request. + + Args: + user_id: the user that we are creating the token for + device_id: the device ID for the dehydrated device. This is to + ensure that the device still exists when the user tells us + they want to use the dehydrated device. + login_submission: the contents of the login request. + Returns: + the dehydration token + """ + # FIXME: expire any old tokens + + attempts = 0 + while attempts < 5: + token = random_string(24) + + try: + await self.db_pool.simple_insert( + table="dehydration_token", + values={ + "token": token, + "user_id": user_id, + "device_id": device_id, + "login_submission": json.dumps(login_submission), + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_dehydration_token", + ) + return token + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a token.") + + def _clear_dehydration_token_txn(self, txn, token: str, dehydrate: bool) -> dict: + token_info = self.db_pool.simple_select_one_txn( + txn, + "dehydration_token", + {"token": token}, + ["user_id", "device_id", "login_submission"], + ) + self.db_pool.simple_delete_one_txn( + txn, "dehydration_token", {"token": token}, + ) + token_info["login_submission"] = json.loads(token_info["login_submission"]) + + if dehydrate: + device_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "dehydrated_devices", + keyvalues={"user_id": token_info["user_id"]}, + retcol="device_id", + allow_none=True, + ) + token_info["dehydrated"] = False + if device_id == token_info["device_id"]: + count = self.db_pool.simple_delete_txn( + txn, + "dehydrated_devices", + { + "user_id": token_info["user_id"], + "device_id": token_info["device_id"], + }, + ) + if count != 0: + token_info["dehydrated"] = True + + return token_info + + async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict: + """Use a dehydration token. If the client wishes to use the dehydrated + device, it will also remove the dehydrated device. + + Args: + token: the dehydration token + dehydrate: whether the client wishes to use the dehydrated device + Returns: + A dict giving the information related to the token. It will have + the following properties: + - user_id: the user associated from the token + - device_id: the ID of the dehydrated device + - login_submission: the original submission to /login + - dehydrated: (only present if the "dehydrate" parameter is True). + Whether the dehydrated device can be used by the client. + """ + return await self.db_pool.runInteraction( + "get_users_whose_devices_changed", + self._clear_dehydration_token_txn, + token, + dehydrate, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f93e0d320d..a1291b06ff 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -271,6 +271,46 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + async def set_e2e_fallback_keys( + self, user_id: str, device_id: str, fallback_keys: dict + ): + # fallback_keys will usually only have one item in it, so using a for + # loop (as opposed to calling simple_upsert_many_txn) won't be too bad + # FIXME: make sure that only one key per algorithm is uploaded + for key_id, fallback_key in fallback_keys.items(): + algorithm, key_id = key_id.split(":", 1) + await self.db_pool.simple_upsert( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm + }, + values={ + "key_id": key_id, + "key_json": json.dumps(fallback_key), + "used": 0 + }, + desc="set_e2e_fallback_key" + ) + + @cached(max_entries=10000) + async def get_e2e_unused_fallback_keys( + self, user_id: str, device_id: str + ): + return await self.db_pool.simple_select_onecol( + "e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "used": 0 + }, + retcol="algorithm", + desc="get_e2e_unused_fallback_keys" + ) + + # FIXME: delete fallbacks when user logs out + async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None ) -> Optional[dict]: @@ -590,15 +630,29 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): " WHERE user_id = ? AND device_id = ? AND algorithm = ?" " LIMIT 1" ) + fallback_sql = ( + "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" + " WHERE user_id = ? AND device_id = ? AND algorithm = ?" + " LIMIT 1" + ) result = {} delete = [] + used_fallbacks = [] for user_id, device_id, algorithm in query_list: user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) + found = False for key_id, key_json in txn: + found = True device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) + if not found: + txn.execute(fallback_sql, (user_id, device_id, algorithm)) + for key_id, key_json, used in txn: + device_result[algorithm + ":" + key_id] = key_json + if used == 0: + used_fallbacks.append((user_id, device_id, algorithm, key_id)) sql = ( "DELETE FROM e2e_one_time_keys_json" " WHERE user_id = ? AND device_id = ? AND algorithm = ?" @@ -615,6 +669,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + for user_id, device_id, algorithm, key_id in used_fallbacks: + self.db_pool.simple_update_txn( + txn, + "e2e_fallback_keys_json", + { + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id + }, + { + "used": 1 + } + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) + ) return result return self.db_pool.runInteraction( @@ -643,6 +714,20 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): self._invalidate_cache_and_stream( txn, self.count_e2e_one_time_keys, (user_id, device_id) ) + self.db_pool.simple_delete_txn( + txn, + table="dehydrated_devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self.db_pool.simple_delete_txn( + txn, + table="e2e_fallback_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_e2e_unused_fallback_keys, (user_id, device_id) + ) + return self.db_pool.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql new file mode 100644 index 0000000000..be5e8a4712 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS dehydrated_devices( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + device_data TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS dehydration_token( + token TEXT NOT NULL PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + login_submission TEXT NOT NULL, + creation_time BIGINT NOT NULL +); + +-- FIXME: index on creation_time to expire old tokens diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql new file mode 100644 index 0000000000..272314a4a8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql @@ -0,0 +1,24 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json ( + user_id TEXT NOT NULL, -- The user this fallback key is for. + device_id TEXT NOT NULL, -- The device this fallback key is for. + algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for. + key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. + key_json TEXT NOT NULL, -- The key as a JSON blob. + used SMALLINT NOT NULL DEFAULT 0, -- Whether the key has been used or not. + CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm) +); diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index db52725cfe..d0c3f40e78 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -754,3 +754,68 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): channel.json_body["error"], "JWT validation failed: Signature verification failed", ) + + +class DehydrationTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + logout.register_servlets, + devices.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.enable_registration = True + self.hs.config.registrations_require_3pid = [] + self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False + + return self.hs + + def test_dehydrate_and_rehydrate_device(self): + self.register_user("kermit", "monkey") + access_token = self.login("kermit", "monkey") + + # dehydrate a device + params = json.dumps({"device_data": "foobar"}) + request, channel = self.make_request( + b"POST", + b"/_matrix/client/unstable/org.matrix.msc2697/device/dehydrate", + params, + access_token=access_token, + ) + self.render(request) + self.assertEquals(channel.code, 200, channel.result) + dehydrated_device_id = channel.json_body["device_id"] + + # Log out + request, channel = self.make_request( + b"POST", "/logout", access_token=access_token + ) + self.render(request) + + # log in, requesting a dehydrated device + params = json.dumps( + { + "type": "m.login.password", + "user": "kermit", + "password": "monkey", + "org.matrix.msc2697.restore_device": True, + } + ) + request, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["device_data"], "foobar") + self.assertEqual(channel.json_body["device_id"], dehydrated_device_id) + dehydration_token = channel.json_body["dehydration_token"] + + params = json.dumps({"rehydrate": True, "dehydration_token": dehydration_token}) + request, channel = self.make_request( + "POST", "/_matrix/client/unstable/org.matrix.msc2697/restore_device", params + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["device_id"], dehydrated_device_id) |