summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py21
-rw-r--r--synapse/replication/http/devices.py4
-rw-r--r--synapse/storage/databases/main/devices.py26
3 files changed, 34 insertions, 17 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9e52af5f13..9356ae998e 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -1030,7 +1030,7 @@ class DeviceListWorkerUpdater:
 
     async def multi_user_device_resync(
         self, user_ids: List[str], mark_failed_as_stale: bool = True
-    ) -> Dict[str, Optional[JsonDict]]:
+    ) -> Dict[str, Optional[JsonMapping]]:
         """
         Like `user_device_resync` but operates on multiple users **from the same origin**
         at once.
@@ -1059,6 +1059,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         self._notifier = hs.get_notifier()
 
         self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+        self._resync_linearizer = Linearizer(name="remote_device_resync")
 
         # user_id -> list of updates waiting to be handled.
         self._pending_updates: Dict[
@@ -1301,7 +1302,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
 
     async def multi_user_device_resync(
         self, user_ids: List[str], mark_failed_as_stale: bool = True
-    ) -> Dict[str, Optional[JsonDict]]:
+    ) -> Dict[str, Optional[JsonMapping]]:
         """
         Like `user_device_resync` but operates on multiple users **from the same origin**
         at once.
@@ -1321,9 +1322,11 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         failed = set()
         # TODO(Perf): Actually batch these up
         for user_id in user_ids:
-            user_result, user_failed = await self._user_device_resync_returning_failed(
-                user_id
-            )
+            async with self._resync_linearizer.queue(user_id):
+                (
+                    user_result,
+                    user_failed,
+                ) = await self._user_device_resync_returning_failed(user_id)
             result[user_id] = user_result
             if user_failed:
                 failed.add(user_id)
@@ -1335,7 +1338,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
 
     async def _user_device_resync_returning_failed(
         self, user_id: str
-    ) -> Tuple[Optional[JsonDict], bool]:
+    ) -> Tuple[Optional[JsonMapping], bool]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
@@ -1348,6 +1351,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
               e.g. due to a connection problem.
             - True iff the resync failed and the device list should be marked as stale.
         """
+        # Check that we haven't gone and fetched the devices since we last
+        # checked if we needed to resync these device lists.
+        if await self.store.get_users_whose_devices_are_cached([user_id]):
+            cached = await self.store.get_cached_devices_for_user(user_id)
+            return cached, False
+
         logger.debug("Attempting to resync the device list for %s", user_id)
         log_kv({"message": "Doing resync to update device list."})
         # Fetch all devices for the user.
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 209833d287..b8198e059c 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -20,7 +20,7 @@ from twisted.web.server import Request
 from synapse.http.server import HttpServer
 from synapse.logging.opentracing import active_span
 from synapse.replication.http._base import ReplicationEndpoint
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -82,7 +82,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
 
     async def _handle_request(  # type: ignore[override]
         self, request: Request, content: JsonDict
-    ) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
+    ) -> Tuple[int, Dict[str, Optional[JsonMapping]]]:
         user_ids: List[str] = content["user_ids"]
 
         logger.info("Resync for %r", user_ids)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 324fdfa892..70faf4b1ec 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -759,18 +759,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             mapping of user_id -> device_id -> device_info.
         """
         unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
-        user_map = await self.get_device_list_last_stream_id_for_remotes(
-            list(unique_user_ids)
-        )
 
-        # We go and check if any of the users need to have their device lists
-        # resynced. If they do then we remove them from the cached list.
-        users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
+        user_ids_in_cache = await self.get_users_whose_devices_are_cached(
             unique_user_ids
         )
-        user_ids_in_cache = {
-            user_id for user_id, stream_id in user_map.items() if stream_id
-        } - users_needing_resync
         user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
         # First fetch all the users which all devices are to be returned.
@@ -792,6 +784,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         return user_ids_not_in_cache, results
 
+    async def get_users_whose_devices_are_cached(
+        self, user_ids: StrCollection
+    ) -> Set[str]:
+        """Checks which of the given users we have cached the devices for."""
+        user_map = await self.get_device_list_last_stream_id_for_remotes(user_ids)
+
+        # We go and check if any of the users need to have their device lists
+        # resynced. If they do then we remove them from the cached list.
+        users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
+            user_ids
+        )
+        user_ids_in_cache = {
+            user_id for user_id, stream_id in user_map.items() if stream_id
+        } - users_needing_resync
+        return user_ids_in_cache
+
     @cached(num_args=2, tree=True)
     async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
         content = await self.db_pool.simple_select_one_onecol(