summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 03d1334e03..93d980786e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1208,6 +1208,65 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
         return devices
 
+    @cached()
+    async def _get_min_device_lists_changes_in_room(self) -> int:
+        """Returns the minimum stream ID that we have entries for
+        `device_lists_changes_in_room`
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="device_lists_changes_in_room",
+            keyvalues={},
+            retcol="COALESCE(MIN(stream_id), 0)",
+            desc="get_min_device_lists_changes_in_room",
+        )
+
+    async def get_device_list_changes_in_rooms(
+        self, room_ids: Collection[str], from_id: int
+    ) -> Optional[Set[str]]:
+        """Return the set of users whose devices have changed in the given rooms
+        since the given stream ID.
+
+        Returns None if the given stream ID is too old.
+        """
+
+        if not room_ids:
+            return set()
+
+        min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+        if min_stream_id > from_id:
+            return None
+
+        sql = """
+            SELECT DISTINCT user_id FROM device_lists_changes_in_room
+            WHERE {clause} AND stream_id >= ?
+        """
+
+        def _get_device_list_changes_in_rooms_txn(
+            txn: LoggingTransaction,
+            clause,
+            args,
+        ) -> Set[str]:
+            txn.execute(sql.format(clause=clause), args)
+            return {user_id for user_id, in txn}
+
+        changes = set()
+        for chunk in batch_iter(room_ids, 1000):
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", chunk
+            )
+            args.append(from_id)
+
+            changes |= await self.db_pool.runInteraction(
+                "get_device_list_changes_in_rooms",
+                _get_device_list_changes_in_rooms_txn,
+                clause,
+                args,
+            )
+
+        return changes
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(