summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py69
-rw-r--r--synapse/handlers/sync.py19
-rw-r--r--synapse/storage/databases/main/devices.py59
3 files changed, 116 insertions, 31 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b79c551703..c05a170c55 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -123,23 +123,28 @@ class DeviceWorkerHandler:
 
         return device
 
-    @trace
-    @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(
-        self, user_id: str, from_token: StreamToken
-    ) -> JsonDict:
-        """Get list of users that have had the devices updated, or have newly
-        joined a room, that `user_id` may be interested in.
+    async def get_device_changes_in_shared_rooms(
+        self, user_id: str, room_ids: Collection[str], from_token: StreamToken
+    ) -> Collection[str]:
+        """Get the set of users whose devices have changed who share a room with
+        the given user.
         """
+        changed_users = await self.store.get_device_list_changes_in_rooms(
+            room_ids, from_token.device_list_key
+        )
 
-        set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
-        now_room_key = self.store.get_room_max_token()
+        if changed_users is not None:
+            # We also check if the given user has changed their device. If
+            # they're in no rooms then the above query won't include them.
+            changed = await self.store.get_users_whose_devices_changed(
+                from_token.device_list_key, [user_id]
+            )
+            changed_users.update(changed)
+            return changed_users
 
-        room_ids = await self.store.get_rooms_for_user(user_id)
+        # If the DB returned None then the `from_token` is too old, so we fall
+        # back on looking for device updates for all users.
 
-        # First we check if any devices have changed for users that we share
-        # rooms with.
         users_who_share_room = await self.store.get_users_who_share_room_with_user(
             user_id
         )
@@ -153,6 +158,27 @@ class DeviceWorkerHandler:
             from_token.device_list_key, tracked_users
         )
 
+        return changed
+
+    @trace
+    @measure_func("device.get_user_ids_changed")
+    async def get_user_ids_changed(
+        self, user_id: str, from_token: StreamToken
+    ) -> JsonDict:
+        """Get list of users that have had the devices updated, or have newly
+        joined a room, that `user_id` may be interested in.
+        """
+
+        set_tag("user_id", user_id)
+        set_tag("from_token", from_token)
+        now_room_key = self.store.get_room_max_token()
+
+        room_ids = await self.store.get_rooms_for_user(user_id)
+
+        changed = await self.get_device_changes_in_shared_rooms(
+            user_id, room_ids, from_token
+        )
+
         # Then work out if any users have since joined
         rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
 
@@ -237,10 +263,19 @@ class DeviceWorkerHandler:
                         break
 
         if possibly_changed or possibly_left:
-            # Take the intersection of the users whose devices may have changed
-            # and those that actually still share a room with the user
-            possibly_joined = possibly_changed & users_who_share_room
-            possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+            possibly_joined = possibly_changed
+            possibly_left = possibly_changed | possibly_left
+
+            # Double check if we still share rooms with the given user.
+            users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
+                possibly_left
+            )
+            for changed_user_id, entries in users_rooms.items():
+                if any(e.room_id in room_ids for e in entries):
+                    possibly_left.discard(changed_user_id)
+                else:
+                    possibly_joined.discard(changed_user_id)
+
         else:
             possibly_joined = set()
             possibly_left = set()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6ad053f678..d42a414c90 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -240,6 +240,7 @@ class SyncHandler:
         self.auth_blocking = hs.get_auth_blocking()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
+        self._device_handler = hs.get_device_handler()
 
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
@@ -1268,21 +1269,11 @@ class SyncHandler:
                     ):
                         users_that_have_changed.add(changed_user_id)
             else:
-                users_who_share_room = (
-                    await self.store.get_users_who_share_room_with_user(user_id)
-                )
-
-                # Always tell the user about their own devices. We check as the user
-                # ID is almost certainly already included (unless they're not in any
-                # rooms) and taking a copy of the set is relatively expensive.
-                if user_id not in users_who_share_room:
-                    users_who_share_room = set(users_who_share_room)
-                    users_who_share_room.add(user_id)
-
-                tracked_users = users_who_share_room
                 users_that_have_changed = (
-                    await self.store.get_users_whose_devices_changed(
-                        since_token.device_list_key, tracked_users
+                    await self._device_handler.get_device_changes_in_shared_rooms(
+                        user_id,
+                        sync_result_builder.joined_room_ids,
+                        from_token=since_token,
                     )
                 )
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 03d1334e03..93d980786e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1208,6 +1208,65 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
         return devices
 
+    @cached()
+    async def _get_min_device_lists_changes_in_room(self) -> int:
+        """Returns the minimum stream ID that we have entries for
+        `device_lists_changes_in_room`
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="device_lists_changes_in_room",
+            keyvalues={},
+            retcol="COALESCE(MIN(stream_id), 0)",
+            desc="get_min_device_lists_changes_in_room",
+        )
+
+    async def get_device_list_changes_in_rooms(
+        self, room_ids: Collection[str], from_id: int
+    ) -> Optional[Set[str]]:
+        """Return the set of users whose devices have changed in the given rooms
+        since the given stream ID.
+
+        Returns None if the given stream ID is too old.
+        """
+
+        if not room_ids:
+            return set()
+
+        min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+        if min_stream_id > from_id:
+            return None
+
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_changes_in_room
+            WHERE {clause} AND stream_id >= ?
+        """
+
+        def _get_device_list_changes_in_rooms_txn(
+            txn: LoggingTransaction,
+            clause,
+            args,
+        ) -> Set[str]:
+            txn.execute(sql.format(clause=clause), args)
+            return {user_id for user_id, in txn}
+
+        changes = set()
+        for chunk in batch_iter(room_ids, 1000):
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", chunk
+            )
+            args.append(from_id)
+
+            changes |= await self.db_pool.runInteraction(
+                "get_device_list_changes_in_rooms",
+                _get_device_list_changes_in_rooms_txn,
+                clause,
+                args,
+            )
+
+        return changes
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(