diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 7cbc1470fd..e00f31da2b 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,7 +14,7 @@
# limitations under the License.
from twisted.internet import defer
-from synapse.api.errors import SynapseError
+from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json
import ujson as json
@@ -123,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
return result
@defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
- """Insert some new one time keys for a device.
+ def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ """Retrieve a number of one-time keys for a user
- Checks if any of the keys are already inserted, if they are then check
- if they match. If they don't then we raise an error.
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ key_ids(list[str]): list of key ids (excluding algorithm) to
+ retrieve
+
+ Returns:
+ deferred resolving to Dict[(str, str), str]: map from (algorithm,
+ key_id) to json string for key
"""
- # First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
- iterable=[key_id for _, key_id, _ in key_list],
+ iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",),
keyvalues={
"user_id": user_id,
@@ -143,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
desc="add_e2e_one_time_keys_check",
)
- existing_key_map = {
+ defer.returnValue({
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
- }
-
- new_keys = [] # Keys that we need to insert
- for algorithm, key_id, json_bytes in key_list:
- ex_bytes = existing_key_map.get((algorithm, key_id), None)
- if ex_bytes:
- if json_bytes != ex_bytes:
- raise SynapseError(
- 400, "One time key with key_id %r already exists" % (key_id,)
- )
- else:
- new_keys.append((algorithm, key_id, json_bytes))
+ })
+
+ @defer.inlineCallbacks
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ """Insert some new one time keys for a device. Errors if any of the
+ keys already exist.
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ time_now(long): insertion time to record (ms since epoch)
+ new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+ (algorithm, key_id, key json)
+ """
def _add_e2e_one_time_keys(txn):
# We are protected from race between lookup and insertion due to
@@ -177,10 +185,14 @@ class EndToEndKeyStore(SQLBaseStore):
for algorithm, key_id, json_bytes in new_keys
],
)
+ txn.call_after(
+ self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ )
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
+ @cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device
Returns:
@@ -225,6 +237,9 @@ class EndToEndKeyStore(SQLBaseStore):
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
+ txn.call_after(
+ self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+ )
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
@@ -242,3 +257,4 @@ class EndToEndKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_one_time_keys_by_device"
)
+ self.count_e2e_one_time_keys.invalidate((user_id, device_id,))
|