summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/devices.py25
-rw-r--r--synapse/storage/end_to_end_keys.py9
2 files changed, 31 insertions, 3 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index c8d5f5ba8b..d9936c88bb 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -18,7 +18,7 @@ import ujson as json
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, Cache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
 
@@ -29,6 +29,14 @@ class DeviceStore(SQLBaseStore):
     def __init__(self, hs):
         super(DeviceStore, self).__init__(hs)
 
+        # Map of (user_id, device_id) -> bool. If there is an entry that implies
+        # the device exists.
+        self.device_id_exists_cache = Cache(
+            name="device_id_exists",
+            keylen=2,
+            max_entries=10000,
+        )
+
         self._clock.looping_call(
             self._prune_old_outbound_device_pokes, 60 * 60 * 1000
         )
@@ -54,6 +62,10 @@ class DeviceStore(SQLBaseStore):
             defer.Deferred: boolean whether the device was inserted or an
                 existing device existed with that ID.
         """
+        key = (user_id, device_id)
+        if self.device_id_exists_cache.get(key, None):
+            defer.returnValue(False)
+
         try:
             inserted = yield self._simple_insert(
                 "devices",
@@ -65,6 +77,7 @@ class DeviceStore(SQLBaseStore):
                 desc="store_device",
                 or_ignore=True,
             )
+            self.device_id_exists_cache.prefill(key, True)
             defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -93,6 +106,7 @@ class DeviceStore(SQLBaseStore):
             desc="get_device",
         )
 
+    @defer.inlineCallbacks
     def delete_device(self, user_id, device_id):
         """Delete a device.
 
@@ -102,12 +116,15 @@ class DeviceStore(SQLBaseStore):
         Returns:
             defer.Deferred
         """
-        return self._simple_delete_one(
+        yield self._simple_delete_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id},
             desc="delete_device",
         )
 
+        self.device_id_exists_cache.invalidate((user_id, device_id))
+
+    @defer.inlineCallbacks
     def delete_devices(self, user_id, device_ids):
         """Deletes several devices.
 
@@ -117,13 +134,15 @@ class DeviceStore(SQLBaseStore):
         Returns:
             defer.Deferred
         """
-        return self._simple_delete_many(
+        yield self._simple_delete_many(
             table="devices",
             column="device_id",
             iterable=device_ids,
             keyvalues={"user_id": user_id},
             desc="delete_devices",
         )
+        for device_id in device_ids:
+            self.device_id_exists_cache.invalidate((user_id, device_id))
 
     def update_device(self, user_id, device_id, new_display_name=None):
         """Update a device.
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 7cbc1470fd..c96dae352d 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -15,6 +15,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
 import ujson as json
@@ -177,10 +178,14 @@ class EndToEndKeyStore(SQLBaseStore):
                     for algorithm, key_id, json_bytes in new_keys
                 ],
             )
+            txn.call_after(
+                self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+            )
         yield self.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
+    @cached(max_entries=10000)
     def count_e2e_one_time_keys(self, user_id, device_id):
         """ Count the number of one time keys the server has for a device
         Returns:
@@ -225,6 +230,9 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             for user_id, device_id, algorithm, key_id in delete:
                 txn.execute(sql, (user_id, device_id, algorithm, key_id))
+                txn.call_after(
+                    self.count_e2e_one_time_keys.invalidate, (user_id, device_id,)
+                )
             return result
         return self.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
@@ -242,3 +250,4 @@ class EndToEndKeyStore(SQLBaseStore):
             keyvalues={"user_id": user_id, "device_id": device_id},
             desc="delete_e2e_one_time_keys_by_device"
         )
+        self.count_e2e_one_time_keys.invalidate((user_id, device_id,))