summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-05-21 16:48:20 +0100
committerGitHub <noreply@github.com>2024-05-21 16:48:20 +0100
commitb5facbac0f2d5f6f0e83d7cac43f8de02ce6742f (patch)
treed11acdaaa92b0f0cd2f3a7b2ee4d2ff8f29d9bfb /synapse
parentMerge branch 'release-v1.108' into develop (diff)
downloadsynapse-b5facbac0f2d5f6f0e83d7cac43f8de02ce6742f.tar.xz
Improve perf of sync device lists (#17216)
Re-introduces #17191, and includes #17197 and #17214

The basic idea is to stop calling `get_rooms_for_user` everywhere, and
instead use the table `device_lists_changes_in_room`.

Commits reviewable one-by-one.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py22
-rw-r--r--synapse/handlers/sync.py38
-rw-r--r--synapse/replication/tcp/client.py15
-rw-r--r--synapse/storage/databases/main/devices.py89
4 files changed, 102 insertions, 62 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 67953a3ed9..55842e7c7b 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -159,20 +159,32 @@ class DeviceWorkerHandler:
 
     @cancellable
     async def get_device_changes_in_shared_rooms(
-        self, user_id: str, room_ids: StrCollection, from_token: StreamToken
+        self,
+        user_id: str,
+        room_ids: StrCollection,
+        from_token: StreamToken,
+        now_token: Optional[StreamToken] = None,
     ) -> Set[str]:
         """Get the set of users whose devices have changed who share a room with
         the given user.
         """
+        now_device_lists_key = self.store.get_device_stream_token()
+        if now_token:
+            now_device_lists_key = now_token.device_list_key
+
         changed_users = await self.store.get_device_list_changes_in_rooms(
-            room_ids, from_token.device_list_key
+            room_ids,
+            from_token.device_list_key,
+            now_device_lists_key,
         )
 
         if changed_users is not None:
             # We also check if the given user has changed their device. If
             # they're in no rooms then the above query won't include them.
             changed = await self.store.get_users_whose_devices_changed(
-                from_token.device_list_key, [user_id]
+                from_token.device_list_key,
+                [user_id],
+                to_key=now_device_lists_key,
             )
             changed_users.update(changed)
             return changed_users
@@ -190,7 +202,9 @@ class DeviceWorkerHandler:
         tracked_users.add(user_id)
 
         changed = await self.store.get_users_whose_devices_changed(
-            from_token.device_list_key, tracked_users
+            from_token.device_list_key,
+            tracked_users,
+            to_key=now_device_lists_key,
         )
 
         return changed
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d3d40e8682..b7917a99d6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1886,38 +1886,14 @@ class SyncHandler:
 
         # Step 1a, check for changes in devices of users we share a room
         # with
-        #
-        # We do this in two different ways depending on what we have cached.
-        # If we already have a list of all the user that have changed since
-        # the last sync then it's likely more efficient to compare the rooms
-        # they're in with the rooms the syncing user is in.
-        #
-        # If we don't have that info cached then we get all the users that
-        # share a room with our user and check if those users have changed.
-        cache_result = self.store.get_cached_device_list_changes(
-            since_token.device_list_key
-        )
-        if cache_result.hit:
-            changed_users = cache_result.entities
-
-            result = await self.store.get_rooms_for_users(changed_users)
-
-            for changed_user_id, entries in result.items():
-                # Check if the changed user shares any rooms with the user,
-                # or if the changed user is the syncing user (as we always
-                # want to include device list updates of their own devices).
-                if user_id == changed_user_id or any(
-                    rid in joined_room_ids for rid in entries
-                ):
-                    users_that_have_changed.add(changed_user_id)
-        else:
-            users_that_have_changed = (
-                await self._device_handler.get_device_changes_in_shared_rooms(
-                    user_id,
-                    sync_result_builder.joined_room_ids,
-                    from_token=since_token,
-                )
+        users_that_have_changed = (
+            await self._device_handler.get_device_changes_in_shared_rooms(
+                user_id,
+                sync_result_builder.joined_room_ids,
+                from_token=since_token,
+                now_token=sync_result_builder.now_token,
             )
+        )
 
         # Step 1b, check for newly joined rooms
         for room_id in newly_joined_rooms:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 5e5387fdcb..2d6d49eed7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -112,6 +112,15 @@ class ReplicationDataHandler:
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
+        all_room_ids: Set[str] = set()
+        if stream_name == DeviceListsStream.NAME:
+            if any(row.entity.startswith("@") and not row.is_signature for row in rows):
+                prev_token = self.store.get_device_stream_token()
+                all_room_ids = await self.store.get_all_device_list_changes(
+                    prev_token, token
+                )
+                self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
+
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
         # NOTE: this must be called after process_replication_rows to ensure any
         # cache invalidations are first handled before any stream ID advances.
@@ -146,12 +155,6 @@ class ReplicationDataHandler:
                     StreamKeyType.TO_DEVICE, token, users=entities
                 )
         elif stream_name == DeviceListsStream.NAME:
-            all_room_ids: Set[str] = set()
-            for row in rows:
-                if row.entity.startswith("@") and not row.is_signature:
-                    room_ids = await self.store.get_rooms_for_user(row.entity)
-                    all_room_ids.update(room_ids)
-
             # `all_room_ids` can be large, so let's wake up those streams in batches
             for batched_room_ids in batch_iter(all_room_ids, 100):
                 self.notifier.on_new_event(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8dbcb3f5a0..f4410b5c02 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -70,10 +70,7 @@ from synapse.types import (
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import (
-    AllEntitiesChangedResult,
-    StreamChangeCache,
-)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
@@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=device_list_prefill,
         )
 
+        device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_lists_changes_in_room",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=device_list_max,
+            limit=10000,
+        )
+        self._device_list_room_stream_cache = StreamChangeCache(
+            "DeviceListRoomStreamChangeCache",
+            min_device_list_room_id,
+            prefilled_cache=device_list_room_prefill,
+        )
+
         (
             user_signature_stream_prefill,
             user_signature_stream_list_id,
@@ -209,6 +220,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                     row.entity, token
                 )
 
+    def device_lists_in_rooms_have_changed(
+        self, room_ids: StrCollection, token: int
+    ) -> None:
+        "Record that device lists have changed in rooms"
+        for room_id in room_ids:
+            self._device_list_room_stream_cache.entity_has_changed(room_id, token)
+
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
@@ -832,16 +850,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         )
         return {device[0]: db_to_json(device[1]) for device in devices}
 
-    def get_cached_device_list_changes(
-        self,
-        from_key: int,
-    ) -> AllEntitiesChangedResult:
-        """Get set of users whose devices have changed since `from_key`, or None
-        if that information is not in our cache.
-        """
-
-        return self._device_list_stream_cache.get_all_entities_changed(from_key)
-
     @cancellable
     async def get_all_devices_changed(
         self,
@@ -1457,7 +1465,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
     @cancellable
     async def get_device_list_changes_in_rooms(
-        self, room_ids: Collection[str], from_id: int
+        self, room_ids: Collection[str], from_id: int, to_id: int
     ) -> Optional[Set[str]]:
         """Return the set of users whose devices have changed in the given rooms
         since the given stream ID.
@@ -1473,9 +1481,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         if min_stream_id > from_id:
             return None
 
+        changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
+            room_ids, from_id
+        )
+        if not changed_room_ids:
+            return set()
+
         sql = """
             SELECT DISTINCT user_id FROM device_lists_changes_in_room
-            WHERE {clause} AND stream_id >= ?
+            WHERE {clause} AND stream_id > ? AND stream_id <= ?
         """
 
         def _get_device_list_changes_in_rooms_txn(
@@ -1487,11 +1501,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             return {user_id for user_id, in txn}
 
         changes = set()
-        for chunk in batch_iter(room_ids, 1000):
+        for chunk in batch_iter(changed_room_ids, 1000):
             clause, args = make_in_list_sql_clause(
                 self.database_engine, "room_id", chunk
             )
             args.append(from_id)
+            args.append(to_id)
 
             changes |= await self.db_pool.runInteraction(
                 "get_device_list_changes_in_rooms",
@@ -1502,6 +1517,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         return changes
 
+    async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
+        """Return the set of rooms where devices have changed since the given
+        stream ID.
+
+        Will raise an exception if the given stream ID is too old.
+        """
+
+        min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+        if min_stream_id > from_id:
+            raise Exception("stream ID is too old")
+
+        sql = """
+            SELECT DISTINCT room_id FROM device_lists_changes_in_room
+            WHERE stream_id > ? AND stream_id <= ?
+        """
+
+        def _get_all_device_list_changes_txn(
+            txn: LoggingTransaction,
+        ) -> Set[str]:
+            txn.execute(sql, (from_id, to_id))
+            return {room_id for room_id, in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_all_device_list_changes",
+            _get_all_device_list_changes_txn,
+        )
+
     async def get_device_list_changes_in_room(
         self, room_id: str, min_stream_id: int
     ) -> Collection[Tuple[str, str]]:
@@ -1962,8 +2005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     async def add_device_change_to_streams(
         self,
         user_id: str,
-        device_ids: Collection[str],
-        room_ids: Collection[str],
+        device_ids: StrCollection,
+        room_ids: StrCollection,
     ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
@@ -2122,8 +2165,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self,
         txn: LoggingTransaction,
         user_id: str,
-        device_ids: Iterable[str],
-        room_ids: Collection[str],
+        device_ids: StrCollection,
+        room_ids: StrCollection,
         stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
@@ -2161,6 +2204,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             ],
         )
 
+        txn.call_after(
+            self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
+        )
+
     async def get_uncoverted_outbound_room_pokes(
         self, start_stream_id: int, start_room_id: str, limit: int = 10
     ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: