diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 8e17800364..d22db0a0b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -19,6 +19,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
logger = logging.getLogger(__name__)
@@ -144,6 +146,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices})
+ @cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
@@ -156,16 +159,36 @@ class DeviceStore(SQLBaseStore):
allow_none=True,
)
+ @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+ list_name="user_ids", inlineCallbacks=True)
+ def get_device_list_last_stream_id_for_remotes(self, user_ids):
+ rows = yield self._simple_select_many_batch(
+ table="device_lists_remote_extremeties",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id", "stream_id",),
+ desc="get_user_devices_from_cache",
+ )
+
+ results = {user_id: None for user_id in user_ids}
+ results.update({
+ row["user_id"]: row["stream_id"] for row in rows
+ })
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- return self._simple_delete(
+ yield self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed",
)
+ self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
@@ -191,6 +214,12 @@ class DeviceStore(SQLBaseStore):
}
)
+ txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+ txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
@@ -234,6 +263,12 @@ class DeviceStore(SQLBaseStore):
]
)
+ txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+ txn.call_after(
+ self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+ )
+
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
@@ -320,6 +355,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results)
+ @defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
@@ -332,27 +368,11 @@ class DeviceStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
- return self.runInteraction(
- "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
- query_list,
+ user_ids = set(user_id for user_id, _ in query_list)
+ user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_ids_in_cache = set(
+ user_id for user_id, stream_id in user_map.items() if stream_id
)
-
- def _get_user_devices_from_cache_txn(self, txn, query_list):
- user_ids = {user_id for user_id, _ in query_list}
-
- user_ids_in_cache = set()
- for user_id in user_ids:
- stream_ids = self._simple_select_onecol_txn(
- txn,
- table="device_lists_remote_extremeties",
- keyvalues={
- "user_id": user_id,
- },
- retcol="stream_id",
- )
- if stream_ids:
- user_ids_in_cache.add(user_id)
-
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
@@ -361,32 +381,40 @@ class DeviceStore(SQLBaseStore):
continue
if device_id:
- content = self._simple_select_one_onecol_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- },
- retcol="content",
- )
- results.setdefault(user_id, {})[device_id] = json.loads(content)
+ device = yield self._get_cached_user_device(user_id, device_id)
+ results.setdefault(user_id, {})[device_id] = device
else:
- devices = self._simple_select_list_txn(
- txn,
- table="device_lists_remote_cache",
- keyvalues={
- "user_id": user_id,
- },
- retcols=("device_id", "content"),
- )
- results[user_id] = {
- device["device_id"]: json.loads(device["content"])
- for device in devices
- }
- user_ids_in_cache.discard(user_id)
+ results[user_id] = yield self._get_cached_devices_for_user(user_id)
- return user_ids_not_in_cache, results
+ defer.returnValue((user_ids_not_in_cache, results))
+
+ @cachedInlineCallbacks(num_args=2, tree=True)
+ def _get_cached_user_device(self, user_id, device_id):
+ content = yield self._simple_select_one_onecol(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ retcol="content",
+ desc="_get_cached_user_device",
+ )
+ defer.returnValue(json.loads(content))
+
+ @cachedInlineCallbacks()
+ def _get_cached_devices_for_user(self, user_id):
+ devices = yield self._simple_select_list(
+ table="device_lists_remote_cache",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("device_id", "content"),
+ desc="_get_cached_devices_for_user",
+ )
+ defer.returnValue({
+ device["device_id"]: json.loads(device["content"])
+ for device in devices
+ })
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
|