summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py124
1 files changed, 98 insertions, 26 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d4750a32e6..89864e1119 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -33,6 +34,7 @@ from synapse.api.errors import (
     Codes,
     FederationDeniedError,
     HttpResponseException,
+    InvalidAPICallError,
     RequestSendFailed,
     SynapseError,
 )
@@ -45,6 +47,7 @@ from synapse.types import (
     JsonDict,
     StreamKeyType,
     StreamToken,
+    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -893,12 +896,47 @@ class DeviceListWorkerUpdater:
 
     def __init__(self, hs: "HomeServer"):
         from synapse.replication.http.devices import (
+            ReplicationMultiUserDevicesResyncRestServlet,
             ReplicationUserDevicesResyncRestServlet,
         )
 
         self._user_device_resync_client = (
             ReplicationUserDevicesResyncRestServlet.make_client(hs)
         )
+        self._multi_user_device_resync_client = (
+            ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
+        )
+
+    async def multi_user_device_resync(
+        self, user_ids: List[str], mark_failed_as_stale: bool = True
+    ) -> Dict[str, Optional[JsonDict]]:
+        """
+        Like `user_device_resync` but operates on multiple users **from the same origin**
+        at once.
+
+        Returns:
+            Dict from User ID to the same Dict as `user_device_resync`.
+        """
+        # mark_failed_as_stale is not sent. Ensure this doesn't break expectations.
+        assert mark_failed_as_stale
+
+        if not user_ids:
+            # Shortcut empty requests
+            return {}
+
+        try:
+            return await self._multi_user_device_resync_client(user_ids=user_ids)
+        except SynapseError as err:
+            if not (
+                err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
+            ):
+                raise
+
+            # Fall back to single requests
+            result: Dict[str, Optional[JsonDict]] = {}
+            for user_id in user_ids:
+                result[user_id] = await self._user_device_resync_client(user_id=user_id)
+            return result
 
     async def user_device_resync(
         self, user_id: str, mark_failed_as_stale: bool = True
@@ -913,8 +951,10 @@ class DeviceListWorkerUpdater:
             A dict with device info as under the "devices" in the result of this
             request:
             https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+            None when we weren't able to fetch the device info for some reason,
+            e.g. due to a connection problem.
         """
-        return await self._user_device_resync_client(user_id=user_id)
+        return (await self.multi_user_device_resync([user_id]))[user_id]
 
 
 class DeviceListUpdater(DeviceListWorkerUpdater):
@@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
             # Allow future calls to retry resyncinc out of sync device lists.
             self._resync_retry_in_progress = False
 
+    async def multi_user_device_resync(
+        self, user_ids: List[str], mark_failed_as_stale: bool = True
+    ) -> Dict[str, Optional[JsonDict]]:
+        """
+        Like `user_device_resync` but operates on multiple users **from the same origin**
+        at once.
+
+        Returns:
+            Dict from User ID to the same Dict as `user_device_resync`.
+        """
+        if not user_ids:
+            return {}
+
+        origins = {UserID.from_string(user_id).domain for user_id in user_ids}
+
+        if len(origins) != 1:
+            raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}")
+
+        result = {}
+        failed = set()
+        # TODO(Perf): Actually batch these up
+        for user_id in user_ids:
+            user_result, user_failed = await self._user_device_resync_returning_failed(
+                user_id
+            )
+            result[user_id] = user_result
+            if user_failed:
+                failed.add(user_id)
+
+        if mark_failed_as_stale:
+            await self.store.mark_remote_users_device_caches_as_stale(failed)
+
+        return result
+
     async def user_device_resync(
         self, user_id: str, mark_failed_as_stale: bool = True
     ) -> Optional[JsonDict]:
+        result, failed = await self._user_device_resync_returning_failed(user_id)
+
+        if failed and mark_failed_as_stale:
+            # Mark the remote user's device list as stale so we know we need to retry
+            # it later.
+            await self.store.mark_remote_users_device_caches_as_stale((user_id,))
+
+        return result
+
+    async def _user_device_resync_returning_failed(
+        self, user_id: str
+    ) -> Tuple[Optional[JsonDict], bool]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
             user_id: The user's id whose device_list will be updated.
-            mark_failed_as_stale: Whether to mark the user's device list as stale
-                if the attempt to resync failed.
         Returns:
-            A dict with device info as under the "devices" in the result of this
-            request:
-            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+            - A dict with device info as under the "devices" in the result of this
+              request:
+              https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+              None when we weren't able to fetch the device info for some reason,
+              e.g. due to a connection problem.
+            - True iff the resync failed and the device list should be marked as stale.
         """
         logger.debug("Attempting to resync the device list for %s", user_id)
         log_kv({"message": "Doing resync to update device list."})
@@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         try:
             result = await self.federation.query_user_devices(origin, user_id)
         except NotRetryingDestination:
-            if mark_failed_as_stale:
-                # Mark the remote user's device list as stale so we know we need to retry
-                # it later.
-                await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
-            return None
+            return None, True
         except (RequestSendFailed, HttpResponseException) as e:
             logger.warning(
                 "Failed to handle device list update for %s: %s",
@@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
                 e,
             )
 
-            if mark_failed_as_stale:
-                # Mark the remote user's device list as stale so we know we need to retry
-                # it later.
-                await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
             # We abort on exceptions rather than accepting the update
             # as otherwise synapse will 'forget' that its device list
             # is out of date. If we bail then we will retry the resync
             # next time we get a device list update for this user_id.
             # This makes it more likely that the device lists will
             # eventually become consistent.
-            return None
+            return None, True
         except FederationDeniedError as e:
             set_tag("error", True)
             log_kv({"reason": "FederationDeniedError"})
             logger.info(e)
-            return None
+            return None, False
         except Exception as e:
             set_tag("error", True)
             log_kv(
@@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
             )
             logger.exception("Failed to handle device list update for %s", user_id)
 
-            if mark_failed_as_stale:
-                # Mark the remote user's device list as stale so we know we need to retry
-                # it later.
-                await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
-            return None
+            return None, True
         log_kv({"result": result})
         stream_id = result["stream_id"]
         devices = result["devices"]
@@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         # point.
         self._seen_updates[user_id] = {stream_id}
 
-        return result
+        return result, False
 
     async def process_cross_signing_key_update(
         self,