diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..1ca66d57d4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -100,6 +101,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
+ ("device_lists_changes_converted_stream_position", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
@@ -201,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
- async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ async def count_devices_by_users(
+ self, user_ids: Optional[Collection[str]] = None
+ ) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -212,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
def count_devices_by_users_txn(
- txn: LoggingTransaction, user_ids: List[str]
+ txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
@@ -745,42 +749,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@trace
@cancellable
async def get_user_devices_from_cache(
- self, query_list: List[Tuple[str, Optional[str]]]
- ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+ self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list: List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
+ user_ids: users which should have all device IDs returned
+ user_and_device_ids: List of (user_id, device_ids)
Returns:
A tuple of (user_ids_not_in_cache, results_map), where
user_ids_not_in_cache is a set of user_ids and results_map is a
mapping of user_id -> device_id -> device_info.
"""
- user_ids = {user_id for user_id, _ in query_list}
- user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+ user_map = await self.get_device_list_last_stream_id_for_remotes(
+ list(unique_user_ids)
+ )
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
- user_ids
+ unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
- user_ids_not_in_cache = user_ids - user_ids_in_cache
-
- results: Dict[str, Dict[str, JsonDict]] = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
+ user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
- if device_id:
- device = await self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
+ # First fetch all the users which all devices are to be returned.
+ results: Dict[str, Mapping[str, JsonDict]] = {}
+ for user_id in user_ids:
+ if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
+ # Then fetch all device-specific requests, but skip users we've already
+ # fetched all devices for.
+ device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+ for user_id, device_id in user_and_device_ids:
+ if user_id in user_ids_in_cache and user_id not in user_ids:
+ device = await self._get_cached_user_device(user_id, device_id)
+ device_specific_results.setdefault(user_id, {})[device_id] = device
+ results.update(device_specific_results)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
- async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
|