summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7955.feature1
-rw-r--r--synapse/handlers/device.py135
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/handlers/sync.py8
-rw-r--r--synapse/rest/client/v1/login.py116
-rw-r--r--synapse/rest/client/v2_alpha/keys.py34
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/storage/databases/main/devices.py171
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py85
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11dehydration.sql30
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11fallback.sql24
-rw-r--r--tests/rest/client/v1/test_login.py65
12 files changed, 669 insertions, 17 deletions
diff --git a/changelog.d/7955.feature b/changelog.d/7955.feature
new file mode 100644
index 0000000000..7d726046fe
--- /dev/null
+++ b/changelog.d/7955.feature
@@ -0,0 +1 @@
+Add support for device dehydration. (MSC2697)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index db417d60de..7c809b27f0 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,8 +14,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -28,6 +29,7 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
+    JsonDict,
     RoomStreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -489,6 +491,137 @@ class DeviceHandler(DeviceWorkerHandler):
             # receive device updates. Mark this in DB.
             await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
 
+    async def store_dehydrated_device(
+        self,
+        user_id: str,
+        device_data: JsonDict,
+        initial_device_display_name: Optional[str] = None,
+    ) -> str:
+        """Store a dehydrated device for a user.  If the user had a previous
+        dehydrated device, it is removed.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_data: the dehydrated device information
+            initial_device_display_name: The display name to use for the device
+        Returns:
+            device id of the dehydrated device
+        """
+        device_id = await self.check_device_registered(
+            user_id, None, initial_device_display_name,
+        )
+        old_device_id = await self.store.store_dehydrated_device(
+            user_id, device_id, device_data
+        )
+        if old_device_id is not None:
+            await self.delete_device(user_id, old_device_id)
+        return device_id
+
+    async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
+    async def create_dehydration_token(
+        self, user_id: str, device_id: str, login_submission: JsonDict
+    ) -> str:
+        """Create a token for a client to fulfill a dehydration request.
+
+        Args:
+            user_id: the user that we are creating the token for
+            device_id: the device ID for the dehydrated device.  This is to
+                ensure that the device still exists when the user tells us
+                they want to use the dehydrated device.
+            login_submission: the contents of the login request.
+        Returns:
+            the dehydration token
+        """
+        return await self.store.create_dehydration_token(
+            user_id, device_id, login_submission
+        )
+
+    async def rehydrate_device(self, token: str) -> dict:
+        """Process a rehydration request from the user.
+
+        Args:
+            token: the dehydration token
+        Returns:
+            the login result, including the user's access token and device ID
+        """
+        # FIXME: if can't find token, return 404
+        token_info = await self.store.clear_dehydration_token(token, True)
+
+        # normally, the constructor would do self.registration_handler =
+        # self.hs.get_registration_handler(), but doing that results in a
+        # circular dependency in the handlers.  So do this for now
+        registration_handler = self.hs.get_registration_handler()
+
+        if token_info["dehydrated"]:
+            # create access token for dehydrated device
+            initial_display_name = (
+                None  # FIXME: get display name from login submission?
+            )
+            device_id, access_token = await registration_handler.register_device(
+                token_info.get("user_id"),
+                token_info.get("device_id"),
+                initial_display_name,
+            )
+
+            return {
+                "user_id": token_info["user_id"],
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+                "device_id": device_id,
+            }
+
+        else:
+            # create device and access token from original login submission
+            login_submission = token_info["login_submission"]
+            device_id = login_submission.get("device_id")
+            initial_display_name = login_submission.get("initial_device_display_name")
+            device_id, access_token = await registration_handler.register_device(
+                token_info.get("user_id"), device_id, initial_display_name
+            )
+
+            return {
+                "user_id": token.info["user_id"],
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+                "device_id": device_id,
+            }
+
+    async def cancel_rehydrate(self, token: str) -> dict:
+        """Cancel a rehydration request from the user and complete the user's login.
+
+        Args:
+            token: the dehydration token
+        Returns:
+            the login result, including the user's access token and device ID
+        """
+        # FIXME: if can't find token, return 404
+        token_info = await self.store.clear_dehydration_token(token, False)
+        # create device and access token from original login submission
+        login_submission = token_info["login_submission"]
+        device_id = login_submission.get("device_id")
+        initial_display_name = login_submission.get("initial_device_display_name")
+        registration_handler = self.hs.get_registration_handler()
+        device_id, access_token = await registration_handler.register_device(
+            token_info.get("user_id"), device_id, initial_display_name
+        )
+
+        return {
+            "user_id": token_info.get("user_id"),
+            "access_token": access_token,
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+
 
 def _update_device_from_client_ips(device, client_ips):
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 84169c1022..0c37829afc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -496,6 +496,22 @@ class E2eKeysHandler(object):
             log_kv(
                 {"message": "Did not update one_time_keys", "reason": "no keys given"}
             )
+        fallback_keys = keys.get("fallback_keys", None)
+        if fallback_keys and isinstance(fallback_keys, dict):
+            log_kv(
+                {
+                    "message": "Updating fallback_keys for device.",
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+            )
+            await self.store.set_e2e_fallback_keys(
+                user_id, device_id, fallback_keys
+            )
+        else:
+            log_kv(
+                {"message": "Did not update fallback_keys", "reason": "no keys given"}
+            )
 
         # the device should have been registered already, but it may have been
         # deleted due to a race with a DELETE request. Or we may be using an
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c42dac18f5..e340b1e615 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -203,6 +203,8 @@ class SyncResult:
         device_lists: List of user_ids whose devices have changed
         device_one_time_keys_count: Dict of algorithm to count for one time keys
             for this device
+        device_unused_fallback_keys: List of key types that have an unused fallback
+            key
         groups: Group updates, if any
     """
 
@@ -215,6 +217,7 @@ class SyncResult:
     to_device = attr.ib(type=List[JsonDict])
     device_lists = attr.ib(type=DeviceLists)
     device_one_time_keys_count = attr.ib(type=JsonDict)
+    device_unused_fallback_keys = attr.ib(type=List[str])
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
     def __nonzero__(self) -> bool:
@@ -1024,10 +1027,14 @@ class SyncHandler(object):
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
         one_time_key_counts = {}  # type: JsonDict
+        unused_fallback_keys = []  # type: list
         if device_id:
             one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
+            unused_fallback_keys = await self.store.get_e2e_unused_fallback_keys(
+                user_id, device_id
+            )
 
         logger.debug("Fetching group data")
         await self._generate_sync_entry_for_groups(sync_result_builder)
@@ -1051,6 +1058,7 @@ class SyncHandler(object):
             device_lists=device_lists,
             groups=sync_result_builder.groups,
             device_one_time_keys_count=one_time_key_counts,
+            device_unused_fallback_keys=unused_fallback_keys,
             next_batch=sync_result_builder.now_token,
         )
 
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 379f668d6f..68fece986b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -103,6 +103,7 @@ class LoginRestServlet(RestServlet):
         self.oidc_enabled = hs.config.oidc_enabled
 
         self.auth_handler = self.hs.get_auth_handler()
+        self.device_handler = hs.get_device_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
         self._well_known_builder = WellKnownBuilder(hs)
@@ -339,6 +340,29 @@ class LoginRestServlet(RestServlet):
                 )
             user_id = canonical_uid
 
+        if login_submission.get("org.matrix.msc2697.restore_device"):
+            # user requested to rehydrate a device, so check if there they have
+            # a dehydrated device, and if so, allow them to try to rehydrate it
+            (
+                device_id,
+                dehydrated_device,
+            ) = await self.device_handler.get_dehydrated_device(user_id)
+            if dehydrated_device:
+                token = await self.device_handler.create_dehydration_token(
+                    user_id, device_id, login_submission
+                )
+                result = {
+                    "user_id": user_id,
+                    "home_server": self.hs.hostname,
+                    "device_data": dehydrated_device,
+                    "device_id": device_id,
+                    "dehydration_token": token,
+                }
+
+                # FIXME: call callback?
+
+                return result
+
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = await self.registration_handler.register_device(
@@ -401,6 +425,96 @@ class LoginRestServlet(RestServlet):
         return result
 
 
+class RestoreDeviceServlet(RestServlet):
+    """Complete a rehydration request, either by letting the client use the
+    dehydrated device, or by creating a new device for the user.
+
+    POST /org.matrix.msc2697/restore_device
+    Content-Type: application/json
+
+    {
+      "rehydrate": true,
+      "dehydration_token": "an_opaque_token"
+    }
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    { // same format as the result from a /login request
+      "user_id": "@alice:example.org",
+      "device_id": "dehydrated_device",
+      "access_token": "another_opaque_token"
+    }
+
+    """
+
+    PATTERNS = client_patterns("/org.matrix.msc2697/restore_device")
+
+    def __init__(self, hs):
+        super(RestoreDeviceServlet, self).__init__()
+        self.hs = hs
+        self.device_handler = hs.get_device_handler()
+        self._well_known_builder = WellKnownBuilder(hs)
+
+    async def on_POST(self, request: SynapseRequest):
+        submission = parse_json_object_from_request(request)
+
+        if submission.get("rehydrate"):
+            result = await self.device_handler.rehydrate_device(
+                submission["dehydration_token"]
+            )
+        else:
+            result = await self.device_handler.cancel_rehydrate(
+                submission["dehydration_token"]
+            )
+        well_known_data = self._well_known_builder.get_well_known()
+        if well_known_data:
+            result["well_known"] = well_known_data
+        return (200, result)
+
+
+class StoreDeviceServlet(RestServlet):
+    """Store a dehydrated device.
+
+    POST /org.matrix.msc2697/device/dehydrate
+    Content-Type: application/json
+
+    {
+      "device_data": {
+        "algorithm": "m.dehydration.v1.olm",
+        "account": "dehydrated_device"
+      }
+    }
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id"
+    }
+
+    """
+
+    PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate")
+
+    def __init__(self, hs):
+        super(StoreDeviceServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.device_handler = hs.get_device_handler()
+
+    async def on_POST(self, request: SynapseRequest):
+        submission = parse_json_object_from_request(request)
+        requester = await self.auth.get_user_by_req(request)
+
+        device_id = await self.device_handler.store_dehydrated_device(
+            requester.user.to_string(),
+            submission["device_data"],
+            submission.get("initial_device_display_name", None)
+        )
+        return 200, {"device_id": device_id}
+
+
 class BaseSSORedirectServlet(RestServlet):
     """Common base class for /login/sso/redirect impls"""
 
@@ -499,6 +613,8 @@ class OIDCRedirectServlet(BaseSSORedirectServlet):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    RestoreDeviceServlet(hs).register(http_server)
+    StoreDeviceServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24bb090822..b86c8f598b 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -67,6 +67,7 @@ class KeyUploadServlet(RestServlet):
         super(KeyUploadServlet, self).__init__()
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.device_handler = hs.get_device_handler()
 
     @trace(opname="upload_keys")
     async def on_POST(self, request, device_id):
@@ -78,20 +79,25 @@ class KeyUploadServlet(RestServlet):
             # passing the device_id here is deprecated; however, we allow it
             # for now for compatibility with older clients.
             if requester.device_id is not None and device_id != requester.device_id:
-                set_tag("error", True)
-                log_kv(
-                    {
-                        "message": "Client uploading keys for a different device",
-                        "logged_in_id": requester.device_id,
-                        "key_being_uploaded": device_id,
-                    }
-                )
-                logger.warning(
-                    "Client uploading keys for a different device "
-                    "(logged in as %s, uploading for %s)",
-                    requester.device_id,
-                    device_id,
-                )
+                (
+                    dehydrated_device_id,
+                    _,
+                ) = await self.device_handler.get_dehydrated_device(user_id)
+                if device_id != dehydrated_device_id:
+                    set_tag("error", True)
+                    log_kv(
+                        {
+                            "message": "Client uploading keys for a different device",
+                            "logged_in_id": requester.device_id,
+                            "key_being_uploaded": device_id,
+                        }
+                    )
+                    logger.warning(
+                        "Client uploading keys for a different device "
+                        "(logged in as %s, uploading for %s)",
+                        requester.device_id,
+                        device_id,
+                    )
         else:
             device_id = requester.device_id
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..6f4b224454 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -237,6 +237,7 @@ class SyncRestServlet(RestServlet):
                 "leave": sync_result.groups.leave,
             },
             "device_one_time_keys_count": sync_result.device_one_time_keys_count,
+            "device_unused_fallback_keys": sync_result.device_unused_fallback_keys,
             "next_batch": sync_result.next_batch.to_string(),
         }
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..f9385a2c83 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -33,9 +33,14 @@ from synapse.storage.database import (
 )
 from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import Cache, cached, cachedList
+from synapse.util.caches.descriptors import (
+    Cache,
+    cached,
+    cachedInlineCallbacks,
+    cachedList,
+)
 from synapse.util.iterutils import batch_iter
-from synapse.util.stringutils import shortstr
+from synapse.util.stringutils import random_string, shortstr
 
 logger = logging.getLogger(__name__)
 
@@ -746,6 +751,168 @@ class DeviceWorkerStore(SQLBaseStore):
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
 
+    async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        # FIXME: make sure device ID still exists in devices table
+        row = await self.db_pool.simple_select_one(
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcols=["device_id", "device_data"],
+            allow_none=True,
+        )
+        return (row["device_id"], json.loads(row["device_data"])) if row else (None, None)
+
+    def _store_dehydrated_device_txn(
+        self, txn, user_id: str, device_id: str, device_data: str
+    ) -> Optional[str]:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcol="device_id",
+            allow_none=True,
+        )
+        if old_device_id is None:
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="dehydrated_devices",
+                values={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "device_data": device_data,
+                },
+            )
+        else:
+            self.db_pool.simple_update_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id},
+                updatevalues={"device_id": device_id, "device_data": device_data},
+            )
+        return old_device_id
+
+    async def store_dehydrated_device(
+        self, user_id: str, device_id: str, device_data: JsonDict
+    ) -> Optional[str]:
+        """Store a dehydrated device for a user.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_data: the dehydrated device information
+            initial_device_display_name: The display name to use for the device
+        Returns:
+            device id of the user's previous dehydrated device, if any
+        """
+        return await self.db_pool.runInteraction(
+            "store_dehydrated_device_txn",
+            self._store_dehydrated_device_txn,
+            user_id,
+            device_id,
+            json.dumps(device_data),
+        )
+
+    async def create_dehydration_token(
+        self, user_id: str, device_id: str, login_submission: JsonDict
+    ) -> str:
+        """Create a token for a client to fulfill a dehydration request.
+
+        Args:
+            user_id: the user that we are creating the token for
+            device_id: the device ID for the dehydrated device.  This is to
+                ensure that the device still exists when the user tells us
+                they want to use the dehydrated device.
+            login_submission: the contents of the login request.
+        Returns:
+            the dehydration token
+        """
+        # FIXME: expire any old tokens
+
+        attempts = 0
+        while attempts < 5:
+            token = random_string(24)
+
+            try:
+                await self.db_pool.simple_insert(
+                    table="dehydration_token",
+                    values={
+                        "token": token,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "login_submission": json.dumps(login_submission),
+                        "creation_time": self.hs.get_clock().time_msec(),
+                    },
+                    desc="create_dehydration_token",
+                )
+                return token
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+        raise StoreError(500, "Couldn't generate a token.")
+
+    def _clear_dehydration_token_txn(self, txn, token: str, dehydrate: bool) -> dict:
+        token_info = self.db_pool.simple_select_one_txn(
+            txn,
+            "dehydration_token",
+            {"token": token},
+            ["user_id", "device_id", "login_submission"],
+        )
+        self.db_pool.simple_delete_one_txn(
+            txn, "dehydration_token", {"token": token},
+        )
+        token_info["login_submission"] = json.loads(token_info["login_submission"])
+
+        if dehydrate:
+            device_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                "dehydrated_devices",
+                keyvalues={"user_id": token_info["user_id"]},
+                retcol="device_id",
+                allow_none=True,
+            )
+            token_info["dehydrated"] = False
+            if device_id == token_info["device_id"]:
+                count = self.db_pool.simple_delete_txn(
+                    txn,
+                    "dehydrated_devices",
+                    {
+                        "user_id": token_info["user_id"],
+                        "device_id": token_info["device_id"],
+                    },
+                )
+                if count != 0:
+                    token_info["dehydrated"] = True
+
+        return token_info
+
+    async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict:
+        """Use a dehydration token.  If the client wishes to use the dehydrated
+        device, it will also remove the dehydrated device.
+
+        Args:
+            token: the dehydration token
+            dehydrate: whether the client wishes to use the dehydrated device
+        Returns:
+            A dict giving the information related to the token.  It will have
+            the following properties:
+            - user_id: the user associated from the token
+            - device_id: the ID of the dehydrated device
+            - login_submission: the original submission to /login
+            - dehydrated: (only present if the "dehydrate" parameter is True).
+              Whether the dehydrated device can be used by the client.
+        """
+        return await self.db_pool.runInteraction(
+            "get_users_whose_devices_changed",
+            self._clear_dehydration_token_txn,
+            token,
+            dehydrate,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..a1291b06ff 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -271,6 +271,46 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def set_e2e_fallback_keys(
+            self, user_id: str, device_id: str, fallback_keys: dict
+    ):
+        # fallback_keys will usually only have one item in it, so using a for
+        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+        # FIXME: make sure that only one key per algorithm is uploaded
+        for key_id, fallback_key in fallback_keys.items():
+            algorithm, key_id = key_id.split(":", 1)
+            await self.db_pool.simple_upsert(
+                "e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm
+                },
+                values={
+                    "key_id": key_id,
+                    "key_json": json.dumps(fallback_key),
+                    "used": 0
+                },
+                desc="set_e2e_fallback_key"
+            )
+
+    @cached(max_entries=10000)
+    async def get_e2e_unused_fallback_keys(
+            self, user_id: str, device_id: str
+    ):
+        return await self.db_pool.simple_select_onecol(
+            "e2e_fallback_keys_json",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+                "used": 0
+            },
+            retcol="algorithm",
+            desc="get_e2e_unused_fallback_keys"
+        )
+
+    # FIXME: delete fallbacks when user logs out
+
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
     ) -> Optional[dict]:
@@ -590,15 +630,29 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                 " LIMIT 1"
             )
+            fallback_sql = (
+                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
             result = {}
             delete = []
+            used_fallbacks = []
             for user_id, device_id, algorithm in query_list:
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
+                found = False
                 for key_id, key_json in txn:
+                    found = True
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
+                if not found:
+                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
+                    for key_id, key_json, used in txn:
+                        device_result[algorithm + ":" + key_id] = key_json
+                        if used == 0:
+                            used_fallbacks.append((user_id, device_id, algorithm, key_id))
             sql = (
                 "DELETE FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -615,6 +669,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 self._invalidate_cache_and_stream(
                     txn, self.count_e2e_one_time_keys, (user_id, device_id)
                 )
+            for user_id, device_id, algorithm, key_id in used_fallbacks:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    "e2e_fallback_keys_json",
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id
+                    },
+                    {
+                        "used": 1
+                    }
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_e2e_unused_fallback_keys, (user_id, device_id)
+                )
             return result
 
         return self.db_pool.runInteraction(
@@ -643,6 +714,20 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_keys, (user_id, device_id)
+            )
+
 
         return self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..be5e8a4712
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+    user_id TEXT NOT NULL PRIMARY KEY,
+    device_id TEXT NOT NULL,
+    device_data TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS dehydration_token(
+    token TEXT NOT NULL PRIMARY KEY,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    login_submission TEXT NOT NULL,
+    creation_time BIGINT NOT NULL
+);
+
+-- FIXME: index on creation_time to expire old tokens
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..272314a4a8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+    user_id TEXT NOT NULL, -- The user this fallback key is for.
+    device_id TEXT NOT NULL, -- The device this fallback key is for.
+    algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+    key_json TEXT NOT NULL, -- The key as a JSON blob.
+    used SMALLINT NOT NULL DEFAULT 0, -- Whether the key has been used or not.
+    CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..d0c3f40e78 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -754,3 +754,68 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
             channel.json_body["error"],
             "JWT validation failed: Signature verification failed",
         )
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        logout.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.hs = self.setup_test_homeserver()
+        self.hs.config.enable_registration = True
+        self.hs.config.registrations_require_3pid = []
+        self.hs.config.auto_join_rooms = []
+        self.hs.config.enable_registration_captcha = False
+
+        return self.hs
+
+    def test_dehydrate_and_rehydrate_device(self):
+        self.register_user("kermit", "monkey")
+        access_token = self.login("kermit", "monkey")
+
+        # dehydrate a device
+        params = json.dumps({"device_data": "foobar"})
+        request, channel = self.make_request(
+            b"POST",
+            b"/_matrix/client/unstable/org.matrix.msc2697/device/dehydrate",
+            params,
+            access_token=access_token,
+        )
+        self.render(request)
+        self.assertEquals(channel.code, 200, channel.result)
+        dehydrated_device_id = channel.json_body["device_id"]
+
+        # Log out
+        request, channel = self.make_request(
+            b"POST", "/logout", access_token=access_token
+        )
+        self.render(request)
+
+        # log in, requesting a dehydrated device
+        params = json.dumps(
+            {
+                "type": "m.login.password",
+                "user": "kermit",
+                "password": "monkey",
+                "org.matrix.msc2697.restore_device": True,
+            }
+        )
+        request, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body["device_data"], "foobar")
+        self.assertEqual(channel.json_body["device_id"], dehydrated_device_id)
+        dehydration_token = channel.json_body["dehydration_token"]
+
+        params = json.dumps({"rehydrate": True, "dehydration_token": dehydration_token})
+        request, channel = self.make_request(
+            "POST", "/_matrix/client/unstable/org.matrix.msc2697/restore_device", params
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body["device_id"], dehydrated_device_id)