summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/end_to_end_keys.py47
1 files changed, 27 insertions, 20 deletions
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index c96dae352d..e00f31da2b 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError
 from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
@@ -124,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
         return result
 
     @defer.inlineCallbacks
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
-        """Insert some new one time keys for a device.
+    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+        """Retrieve a number of one-time keys for a user
 
-        Checks if any of the keys are already inserted, if they are then check
-        if they match. If they don't then we raise an error.
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            key_ids(list[str]): list of key ids (excluding algorithm) to
+                retrieve
+
+        Returns:
+            deferred resolving to Dict[(str, str), str]: map from (algorithm,
+            key_id) to json string for key
         """
 
-        # First we check if we have already persisted any of the keys.
         rows = yield self._simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
-            iterable=[key_id for _, key_id, _ in key_list],
+            iterable=key_ids,
             retcols=("algorithm", "key_id", "key_json",),
             keyvalues={
                 "user_id": user_id,
@@ -144,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
             desc="add_e2e_one_time_keys_check",
         )
 
-        existing_key_map = {
+        defer.returnValue({
             (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
-        }
-
-        new_keys = []  # Keys that we need to insert
-        for algorithm, key_id, json_bytes in key_list:
-            ex_bytes = existing_key_map.get((algorithm, key_id), None)
-            if ex_bytes:
-                if json_bytes != ex_bytes:
-                    raise SynapseError(
-                        400, "One time key with key_id %r already exists" % (key_id,)
-                    )
-            else:
-                new_keys.append((algorithm, key_id, json_bytes))
+        })
+
+    @defer.inlineCallbacks
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+        """Insert some new one time keys for a device. Errors if any of the
+        keys already exist.
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            time_now(long): insertion time to record (ms since epoch)
+            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+                (algorithm, key_id, key json)
+        """
 
         def _add_e2e_one_time_keys(txn):
             # We are protected from race between lookup and insertion due to