summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8312.feature1
-rwxr-xr-xscripts/synapse_port_db1
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/handlers/sync.py8
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py100
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11fallback.sql24
-rw-r--r--tests/handlers/test_e2e_keys.py65
8 files changed, 215 insertions, 1 deletions
diff --git a/changelog.d/8312.feature b/changelog.d/8312.feature
new file mode 100644
index 0000000000..222a1b032a
--- /dev/null
+++ b/changelog.d/8312.feature
@@ -0,0 +1 @@
+Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
\ No newline at end of file
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 7e12f5440c..2d0b59ab53 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
     "room_stats_state": ["is_federatable"],
     "local_media_repository": ["safe_from_quarantine"],
     "users": ["shadow_banned"],
+    "e2e_fallback_keys_json": ["used"],
 }
 
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index dd40fd1299..611742ae72 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -496,6 +496,22 @@ class E2eKeysHandler:
             log_kv(
                 {"message": "Did not update one_time_keys", "reason": "no keys given"}
             )
+        fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+        if fallback_keys and isinstance(fallback_keys, dict):
+            log_kv(
+                {
+                    "message": "Updating fallback_keys for device.",
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+            )
+            await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
+        elif fallback_keys:
+            log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
+        else:
+            log_kv(
+                {"message": "Did not update fallback_keys", "reason": "no keys given"}
+            )
 
         # the device should have been registered already, but it may have been
         # deleted due to a race with a DELETE request. Or we may be using an
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a998e6b7f6..dd1f90e359 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -201,6 +201,8 @@ class SyncResult:
         device_lists: List of user_ids whose devices have changed
         device_one_time_keys_count: Dict of algorithm to count for one time keys
             for this device
+        device_unused_fallback_key_types: List of key types that have an unused fallback
+            key
         groups: Group updates, if any
     """
 
@@ -213,6 +215,7 @@ class SyncResult:
     to_device = attr.ib(type=List[JsonDict])
     device_lists = attr.ib(type=DeviceLists)
     device_one_time_keys_count = attr.ib(type=JsonDict)
+    device_unused_fallback_key_types = attr.ib(type=List[str])
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
     def __bool__(self) -> bool:
@@ -1014,10 +1017,14 @@ class SyncHandler:
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
         one_time_key_counts = {}  # type: JsonDict
+        unused_fallback_key_types = []  # type: List[str]
         if device_id:
             one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
+            unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
+                user_id, device_id
+            )
 
         logger.debug("Fetching group data")
         await self._generate_sync_entry_for_groups(sync_result_builder)
@@ -1041,6 +1048,7 @@ class SyncHandler:
             device_lists=device_lists,
             groups=sync_result_builder.groups,
             device_one_time_keys_count=one_time_key_counts,
+            device_unused_fallback_key_types=unused_fallback_key_types,
             next_batch=sync_result_builder.now_token,
         )
 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6779df952f..2b84eb89c0 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -236,6 +236,7 @@ class SyncRestServlet(RestServlet):
                 "leave": sync_result.groups.leave,
             },
             "device_one_time_keys_count": sync_result.device_one_time_keys_count,
+            "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
             "next_batch": await sync_result.next_batch.to_string(self.store),
         }
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..8c97f2af5c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def set_e2e_fallback_keys(
+        self, user_id: str, device_id: str, fallback_keys: JsonDict
+    ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                of the form "algorithm:id") to key data.
+        """
+        # fallback_keys will usually only have one item in it, so using a for
+        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+        # FIXME: make sure that only one key per algorithm is uploaded
+        for key_id, fallback_key in fallback_keys.items():
+            algorithm, key_id = key_id.split(":", 1)
+            await self.db_pool.simple_upsert(
+                "e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                },
+                values={
+                    "key_id": key_id,
+                    "key_json": json_encoder.encode(fallback_key),
+                    "used": False,
+                },
+                desc="set_e2e_fallback_key",
+            )
+
+    @cached(max_entries=10000)
+    async def get_e2e_unused_fallback_key_types(
+        self, user_id: str, device_id: str
+    ) -> List[str]:
+        """Returns the fallback key types that have an unused key.
+
+        Args:
+            user_id: the user whose keys are being queried
+            device_id: the device whose keys are being queried
+
+        Returns:
+            a list of key types
+        """
+        return await self.db_pool.simple_select_onecol(
+            "e2e_fallback_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+            retcol="algorithm",
+            desc="get_e2e_unused_fallback_key_types",
+        )
+
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
     ) -> Optional[dict]:
@@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                 " LIMIT 1"
             )
+            fallback_sql = (
+                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
             result = {}
             delete = []
+            used_fallbacks = []
             for user_id, device_id, algorithm in query_list:
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn:
+                otk_row = txn.fetchone()
+                if otk_row is not None:
+                    key_id, key_json = otk_row
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
+                else:
+                    # no one-time key available, so see if there's a fallback
+                    # key
+                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
+                    fallback_row = txn.fetchone()
+                    if fallback_row is not None:
+                        key_id, key_json, used = fallback_row
+                        device_result[algorithm + ":" + key_id] = key_json
+                        if not used:
+                            used_fallbacks.append(
+                                (user_id, device_id, algorithm, key_id)
+                            )
+
+            # drop any one-time keys that were claimed
             sql = (
                 "DELETE FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 self._invalidate_cache_and_stream(
                     txn, self.count_e2e_one_time_keys, (user_id, device_id)
                 )
+            # mark fallback keys as used
+            for user_id, device_id, algorithm, key_id in used_fallbacks:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    "e2e_fallback_keys_json",
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id,
+                    },
+                    {"used": True},
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+                )
+
             return result
 
         return await self.db_pool.runInteraction(
@@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+            )
 
         await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+    user_id TEXT NOT NULL, -- The user this fallback key is for.
+    device_id TEXT NOT NULL, -- The device this fallback key is for.
+    algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+    key_json TEXT NOT NULL, -- The key as a JSON blob.
+    used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+    CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 366dcfb670..4e9e3dcbc2 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -172,6 +172,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_fallback_key(self):
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        fallback_key = {"alg1:k1": "key1"}
+        otk = {"alg1:k2": "key2"}
+
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key},
+            )
+        )
+
+        # claiming an OTK when no OTKs are available should return the fallback
+        # key
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # claiming an OTK again should return the same fallback key
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # if the user uploads a one-time key, the next claim should fetch the
+        # one-time key, and then go back to the fallback
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": otk}
+            )
+        )
+
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+        )
+
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+    @defer.inlineCallbacks
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname