summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11382.misc1
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py51
-rw-r--r--tests/handlers/test_e2e_keys.py32
3 files changed, 72 insertions, 12 deletions
diff --git a/changelog.d/11382.misc b/changelog.d/11382.misc
new file mode 100644
index 0000000000..d812ef309e
--- /dev/null
+++ b/changelog.d/11382.misc
@@ -0,0 +1 @@
+Keep fallback key marked as used if it's re-uploaded.
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..b06c1dc45b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
             fallback_keys: the keys to set.  This is a map from key ID (which is
                 of the form "algorithm:id") to key data.
         """
+        await self.db_pool.runInteraction(
+            "set_e2e_fallback_keys_txn",
+            self._set_e2e_fallback_keys_txn,
+            user_id,
+            device_id,
+            fallback_keys,
+        )
+
+        await self.invalidate_cache_and_stream(
+            "get_e2e_unused_fallback_key_types", (user_id, device_id)
+        )
+
+    def _set_e2e_fallback_keys_txn(
+        self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+    ) -> None:
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
         # FIXME: make sure that only one key per algorithm is uploaded
         for key_id, fallback_key in fallback_keys.items():
             algorithm, key_id = key_id.split(":", 1)
-            await self.db_pool.simple_upsert(
-                "e2e_fallback_keys_json",
+            old_key_json = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="e2e_fallback_keys_json",
                 keyvalues={
                     "user_id": user_id,
                     "device_id": device_id,
                     "algorithm": algorithm,
                 },
-                values={
-                    "key_id": key_id,
-                    "key_json": json_encoder.encode(fallback_key),
-                    "used": False,
-                },
-                desc="set_e2e_fallback_key",
+                retcol="key_json",
+                allow_none=True,
             )
 
-        await self.invalidate_cache_and_stream(
-            "get_e2e_unused_fallback_key_types", (user_id, device_id)
-        )
+            new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
+
+            # If the uploaded key is the same as the current fallback key,
+            # don't do anything.  This prevents marking the key as unused if it
+            # was already used.
+            if old_key_json != new_key_json:
+                self.db_pool.simple_upsert_txn(
+                    txn,
+                    table="e2e_fallback_keys_json",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                    },
+                    values={
+                        "key_id": key_id,
+                        "key_json": json_encoder.encode(fallback_key),
+                        "used": False,
+                    },
+                )
 
     @cached(max_entries=10000)
     async def get_e2e_unused_fallback_key_types(
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 0c3b86fda9..f0723892e4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -162,6 +162,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         fallback_key = {"alg1:k1": "key1"}
+        fallback_key2 = {"alg1:k2": "key2"}
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
@@ -213,6 +214,35 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
+        # re-uploading the same fallback key should still result in no unused fallback
+        # keys
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key},
+            )
+        )
+
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        # uploading a new fallback key should result in an unused fallback key
+        self.get_success(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key2},
+            )
+        )
+
+        res = self.get_success(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, ["alg1"])
+
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
         self.get_success(
@@ -238,7 +268,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(
             res,
-            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
         )
 
     def test_replace_master_key(self):