summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/end_to_end_keys.py57
1 files changed, 48 insertions, 9 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b9f1365f92..77bb9b0657 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -120,24 +120,63 @@ class EndToEndKeyStore(SQLBaseStore):
 
         return result
 
+    @defer.inlineCallbacks
     def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+        """Insert some new one time keys for a device.
+
+        Checks if any of the keys are already inserted, if they are then check
+        if they match. If they don't then we raise an error.
+        """
+
+        # First we check if we have already persisted any of the keys.
+        rows = yield self._simple_select_many_batch(
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            iterable=[key_id for _, key_id, _ in key_list],
+            retcols=("algorithm", "key_id", "key_json",),
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            desc="add_e2e_one_time_keys_check",
+        )
+
+        existing_key_map = {
+            row["key_id"]: (row["algorithm"], row["key_json"]) for row in rows
+        }
+
+        new_keys = []  # Keys that we need to insert
+        for algorithm, key_id, json_bytes in key_list:
+            if key_id in existing_key_map:
+                ex_algo, ex_bytes = existing_key_map[key_id]
+                if algorithm != ex_algo or json_bytes != ex_bytes:
+                    raise Exception(
+                        "One time key with key_id %r already exists" % (key_id,)
+                    )
+            else:
+                new_keys.append((algorithm, key_id, json_bytes))
+
         def _add_e2e_one_time_keys(txn):
-            for (algorithm, key_id, json_bytes) in key_list:
-                self._simple_upsert_txn(
-                    txn, table="e2e_one_time_keys_json",
-                    keyvalues={
+            # We are protected from race between lookup and insertion due to
+            # a unique constraint. If there is a race of two calls to
+            # `add_e2e_one_time_keys` then they'll conflict and we will only
+            # insert one set.
+            self._simple_insert_many_txn(
+                txn, table="e2e_one_time_keys_json",
+                values=[
+                    {
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
-                    },
-                    values={
                         "ts_added_ms": time_now,
                         "key_json": json_bytes,
                     }
-                )
-        return self.runInteraction(
-            "add_e2e_one_time_keys", _add_e2e_one_time_keys
+                    for algorithm, key_id, json_bytes in new_keys
+                ],
+            )
+        yield self.runInteraction(
+            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
     def count_e2e_one_time_keys(self, user_id, device_id):