summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/devices.py55
-rw-r--r--synapse/storage/databases/main/room.py16
2 files changed, 58 insertions, 13 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1e562d4a40..18358eca46 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1307,6 +1307,33 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         return changes
 
+    async def get_device_list_changes_in_room(
+        self, room_id: str, min_stream_id: int
+    ) -> Collection[Tuple[str, str]]:
+        """Get all device list changes that happened in the room since the given
+        stream ID.
+
+        Returns:
+            Collection of user ID/device ID tuples of all devices that have
+            changed
+        """
+
+        sql = """
+            SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
+            WHERE room_id = ? AND stream_id > ?
+        """
+
+        def get_device_list_changes_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Collection[Tuple[str, str]]:
+            txn.execute(sql, (room_id, min_stream_id))
+            return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+        return await self.db_pool.runInteraction(
+            "get_device_list_changes_in_room",
+            get_device_list_changes_in_room_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -1946,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_id: str,
         room_id: str,
-        stream_id: int,
+        stream_id: Optional[int],
         hosts: Collection[str],
         context: Optional[Dict[str, str]],
     ) -> None:
         """Queue the device update to be sent to the given set of hosts,
         calculated from the room ID.
 
-        Marks the associated row in `device_lists_changes_in_room` as handled.
+        Marks the associated row in `device_lists_changes_in_room` as handled,
+        if `stream_id` is provided.
         """
 
         def add_device_list_outbound_pokes_txn(
@@ -1969,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     context=context,
                 )
 
-            self.db_pool.simple_update_txn(
-                txn,
-                table="device_lists_changes_in_room",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "stream_id": stream_id,
-                    "room_id": room_id,
-                },
-                updatevalues={"converted_to_destinations": True},
-            )
+            if stream_id:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    table="device_lists_changes_in_room",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "stream_id": stream_id,
+                        "room_id": room_id,
+                    },
+                    updatevalues={"converted_to_destinations": True},
+                )
 
         if not hosts:
             # If there are no hosts then we don't try and generate stream IDs.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 672c9a03fc..059eef5c22 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1256,6 +1256,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         return entry is not None
 
+    async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
+        self, room_id: str
+    ) -> Tuple[str, int]:
+        """Get the event ID of the initial join that started the partial
+        join, and the device list stream ID at the point we started the partial
+        join.
+        """
+
+        result = await self.db_pool.simple_select_one(
+            table="partial_state_rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("join_event_id", "device_lists_stream_id"),
+            desc="get_join_event_id_for_partial_state",
+        )
+        return result["join_event_id"], result["device_lists_stream_id"]
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"