diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a95ac34f09..b06c1dc45b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -408,29 +408,58 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
+ await self.db_pool.runInteraction(
+ "set_e2e_fallback_keys_txn",
+ self._set_e2e_fallback_keys_txn,
+ user_id,
+ device_id,
+ fallback_keys,
+ )
+
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
+ )
+
+ def _set_e2e_fallback_keys_txn(
+ self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
+ ) -> None:
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
- await self.db_pool.simple_upsert(
- "e2e_fallback_keys_json",
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
- values={
- "key_id": key_id,
- "key_json": json_encoder.encode(fallback_key),
- "used": False,
- },
- desc="set_e2e_fallback_key",
+ retcol="key_json",
+ allow_none=True,
)
- await self.invalidate_cache_and_stream(
- "get_e2e_unused_fallback_key_types", (user_id, device_id)
- )
+ new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
+
+ # If the uploaded key is the same as the current fallback key,
+ # don't do anything. This prevents marking the key as unused if it
+ # was already used.
+ if old_key_json != new_key_json:
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ values={
+ "key_id": key_id,
+ "key_json": json_encoder.encode(fallback_key),
+ "used": False,
+ },
+ )
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
|