summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/device.py100
1 files changed, 81 insertions, 19 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f9cc5bddbc..d4750a32e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
 
 class DeviceWorkerHandler:
+    device_list_updater: "DeviceListWorkerUpdater"
+
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
+        self.device_list_updater = DeviceListWorkerUpdater(hs)
+
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
         log_kv(device_map)
         return devices
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
     @trace
     async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
     @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
-    ) -> Collection[str]:
+    ) -> Set[str]:
         """Get the set of users whose devices have changed who share a room with
         the given user.
         """
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
 
 
 class DeviceHandler(DeviceWorkerHandler):
+    device_list_updater: "DeviceListUpdater"
+
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
             await self.delete_devices(user_id, [old_device_id])
         return device_id
 
-    async def get_dehydrated_device(
-        self, user_id: str
-    ) -> Optional[Tuple[str, JsonDict]]:
-        """Retrieve the information for a dehydrated device.
-
-        Args:
-            user_id: the user whose dehydrated device we are looking for
-        Returns:
-            a tuple whose first item is the device ID, and the second item is
-            the dehydrated device information
-        """
-        return await self.store.get_dehydrated_device(user_id)
-
     async def rehydrate_device(
         self, user_id: str, access_token: str, device_id: str
     ) -> dict:
@@ -682,13 +688,33 @@ class DeviceHandler(DeviceWorkerHandler):
         hosts_already_sent_to: Set[str] = set()
 
         try:
+            stream_id, room_id = await self.store.get_device_change_last_converted_pos()
+
             while True:
                 self._handle_new_device_update_new_data = False
-                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                max_stream_id = self.store.get_device_stream_token()
+                rows = await self.store.get_uncoverted_outbound_room_pokes(
+                    stream_id, room_id
+                )
                 if not rows:
                     # If the DB returned nothing then there is nothing left to
                     # do, *unless* a new device list update happened during the
                     # DB query.
+
+                    # Advance `(stream_id, room_id)`.
+                    # `max_stream_id` comes from *before* the query for unconverted
+                    # rows, which means that any unconverted rows must have a larger
+                    # stream ID.
+                    if max_stream_id > stream_id:
+                        stream_id, room_id = max_stream_id, ""
+                        await self.store.set_device_change_last_converted_pos(
+                            stream_id, room_id
+                        )
+                    else:
+                        assert max_stream_id == stream_id
+                        # Avoid moving `room_id` backwards.
+                        pass
+
                     if self._handle_new_device_update_new_data:
                         continue
                     else:
@@ -718,7 +744,6 @@ class DeviceHandler(DeviceWorkerHandler):
                         user_id=user_id,
                         device_id=device_id,
                         room_id=room_id,
-                        stream_id=stream_id,
                         hosts=hosts,
                         context=opentracing_context,
                     )
@@ -752,6 +777,12 @@ class DeviceHandler(DeviceWorkerHandler):
                     hosts_already_sent_to.update(hosts)
                     current_stream_id = stream_id
 
+                # Advance `(stream_id, room_id)`.
+                _, _, room_id, stream_id, _ = rows[-1]
+                await self.store.set_device_change_last_converted_pos(
+                    stream_id, room_id
+                )
+
         finally:
             self._handle_new_device_update_is_processing = False
 
@@ -834,7 +865,6 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=device_id,
                 room_id=room_id,
-                stream_id=None,
                 hosts=potentially_changed_hosts,
                 context=None,
             )
@@ -858,7 +888,36 @@ def _update_device_from_client_ips(
     )
 
 
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+    "Handles incoming device list updates from federation and contacts the main process over replication"
+
+    def __init__(self, hs: "HomeServer"):
+        from synapse.replication.http.devices import (
+            ReplicationUserDevicesResyncRestServlet,
+        )
+
+        self._user_device_resync_client = (
+            ReplicationUserDevicesResyncRestServlet.make_client(hs)
+        )
+
+    async def user_device_resync(
+        self, user_id: str, mark_failed_as_stale: bool = True
+    ) -> Optional[JsonDict]:
+        """Fetches all devices for a user and updates the device cache with them.
+
+        Args:
+            user_id: The user's id whose device_list will be updated.
+            mark_failed_as_stale: Whether to mark the user's device list as stale
+                if the attempt to resync failed.
+        Returns:
+            A dict with device info as under the "devices" in the result of this
+            request:
+            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+        """
+        return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
@@ -937,7 +996,10 @@ class DeviceListUpdater:
         # Check if we are partially joining any rooms. If so we need to store
         # all device list updates so that we can handle them correctly once we
         # know who is in the room.
-        partial_rooms = await self.store.get_partial_state_rooms_and_servers()
+        # TODO(faster_joins): this fetches and processes a bunch of data that we don't
+        # use. Could be replaced by a tighter query e.g.
+        #   SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+        partial_rooms = await self.store.get_partial_state_room_resync_info()
         if partial_rooms:
             await self.store.add_remote_device_list_to_pending(
                 user_id,