summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1151fb0cc3..1e562d4a40 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1995,3 +1995,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 add_device_list_outbound_pokes_txn,
                 stream_ids,
             )
+
+    async def add_remote_device_list_to_pending(
+        self, user_id: str, device_id: str
+    ) -> None:
+        """Add a device list update to the table tracking remote device list
+        updates during partial joins.
+        """
+
+        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+            await self.db_pool.simple_upsert(
+                table="device_lists_remote_pending",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={"stream_id": stream_id},
+                desc="add_remote_device_list_to_pending",
+            )
+
+    async def get_pending_remote_device_list_updates_for_room(
+        self, room_id: str
+    ) -> Collection[Tuple[str, str]]:
+        """Get the set of remote device list updates from the pending table for
+        the room.
+        """
+
+        min_device_stream_id = await self.db_pool.simple_select_one_onecol(
+            table="partial_state_rooms",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="device_lists_stream_id",
+            desc="get_pending_remote_device_list_updates_for_room_device",
+        )
+
+        sql = """
+            SELECT user_id, device_id FROM device_lists_remote_pending AS d
+            INNER JOIN current_state_events AS c ON
+                type = 'm.room.member'
+                AND state_key = user_id
+                AND membership = 'join'
+            WHERE
+                room_id = ? AND stream_id > ?
+        """
+
+        def get_pending_remote_device_list_updates_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Collection[Tuple[str, str]]:
+            txn.execute(sql, (room_id, min_device_stream_id))
+            return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+        return await self.db_pool.runInteraction(
+            "get_pending_remote_device_list_updates_for_room",
+            get_pending_remote_device_list_updates_for_room_txn,
+        )