summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/devices.py118
1 files changed, 73 insertions, 45 deletions
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 8e17800364..d22db0a0b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -19,6 +19,8 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
 
 logger = logging.getLogger(__name__)
 
@@ -144,6 +146,7 @@ class DeviceStore(SQLBaseStore):
 
         defer.returnValue({d["device_id"]: d for d in devices})
 
+    @cached(max_entries=10000)
     def get_device_list_last_stream_id_for_remote(self, user_id):
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
@@ -156,16 +159,36 @@ class DeviceStore(SQLBaseStore):
             allow_none=True,
         )
 
+    @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+                list_name="user_ids", inlineCallbacks=True)
+    def get_device_list_last_stream_id_for_remotes(self, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="device_lists_remote_extremeties",
+            column="user_id",
+            iterable=user_ids,
+            retcols=("user_id", "stream_id",),
+            desc="get_user_devices_from_cache",
+        )
+
+        results = {user_id: None for user_id in user_ids}
+        results.update({
+            row["user_id"]: row["stream_id"] for row in rows
+        })
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
     def mark_remote_user_device_list_as_unsubscribed(self, user_id):
         """Mark that we no longer track device lists for remote user.
         """
-        return self._simple_delete(
+        yield self._simple_delete(
             table="device_lists_remote_extremeties",
             keyvalues={
                 "user_id": user_id,
             },
             desc="mark_remote_user_device_list_as_unsubscribed",
         )
+        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
 
     def update_remote_device_list_cache_entry(self, user_id, device_id, content,
                                               stream_id):
@@ -191,6 +214,12 @@ class DeviceStore(SQLBaseStore):
             }
         )
 
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(
+            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+        )
+
         self._simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
@@ -234,6 +263,12 @@ class DeviceStore(SQLBaseStore):
             ]
         )
 
+        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+        txn.call_after(
+            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+        )
+
         self._simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
@@ -320,6 +355,7 @@ class DeviceStore(SQLBaseStore):
 
         return (now_stream_id, results)
 
+    @defer.inlineCallbacks
     def get_user_devices_from_cache(self, query_list):
         """Get the devices (and keys if any) for remote users from the cache.
 
@@ -332,27 +368,11 @@ class DeviceStore(SQLBaseStore):
             a set of user_ids and results_map is a mapping of
             user_id -> device_id -> device_info
         """
-        return self.runInteraction(
-            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
-            query_list,
+        user_ids = set(user_id for user_id, _ in query_list)
+        user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        user_ids_in_cache = set(
+            user_id for user_id, stream_id in user_map.items() if stream_id
         )
-
-    def _get_user_devices_from_cache_txn(self, txn, query_list):
-        user_ids = {user_id for user_id, _ in query_list}
-
-        user_ids_in_cache = set()
-        for user_id in user_ids:
-            stream_ids = self._simple_select_onecol_txn(
-                txn,
-                table="device_lists_remote_extremeties",
-                keyvalues={
-                    "user_id": user_id,
-                },
-                retcol="stream_id",
-            )
-            if stream_ids:
-                user_ids_in_cache.add(user_id)
-
         user_ids_not_in_cache = user_ids - user_ids_in_cache
 
         results = {}
@@ -361,32 +381,40 @@ class DeviceStore(SQLBaseStore):
                 continue
 
             if device_id:
-                content = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="device_lists_remote_cache",
-                    keyvalues={
-                        "user_id": user_id,
-                        "device_id": device_id,
-                    },
-                    retcol="content",
-                )
-                results.setdefault(user_id, {})[device_id] = json.loads(content)
+                device = yield self._get_cached_user_device(user_id, device_id)
+                results.setdefault(user_id, {})[device_id] = device
             else:
-                devices = self._simple_select_list_txn(
-                    txn,
-                    table="device_lists_remote_cache",
-                    keyvalues={
-                        "user_id": user_id,
-                    },
-                    retcols=("device_id", "content"),
-                )
-                results[user_id] = {
-                    device["device_id"]: json.loads(device["content"])
-                    for device in devices
-                }
-                user_ids_in_cache.discard(user_id)
+                results[user_id] = yield self._get_cached_devices_for_user(user_id)
 
-        return user_ids_not_in_cache, results
+        defer.returnValue((user_ids_not_in_cache, results))
+
+    @cachedInlineCallbacks(num_args=2, tree=True)
+    def _get_cached_user_device(self, user_id, device_id):
+        content = yield self._simple_select_one_onecol(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            retcol="content",
+            desc="_get_cached_user_device",
+        )
+        defer.returnValue(json.loads(content))
+
+    @cachedInlineCallbacks()
+    def _get_cached_devices_for_user(self, user_id):
+        devices = yield self._simple_select_list(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+            retcols=("device_id", "content"),
+            desc="_get_cached_devices_for_user",
+        )
+        defer.returnValue({
+            device["device_id"]: json.loads(device["content"])
+            for device in devices
+        })
 
     def get_devices_with_keys_by_user(self, user_id):
         """Get all devices (with any device keys) for a user