summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py174
1 files changed, 147 insertions, 27 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfca34550d..d7f015c783 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
@@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="mark_remote_user_device_cache_as_valid",
         )
 
+    async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        await self.db_pool.runInteraction(
+            "_handle_potentially_left_users",
+            self.handle_potentially_left_users_txn,
+            user_ids,
+        )
+
+    def handle_potentially_left_users_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Set[str],
+    ) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+        left_users = user_ids - joined_users
+
+        for user_id in left_users:
+            self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(
-            txn: LoggingTransaction,
-        ) -> None:
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="device_lists_remote_extremeties",
-                keyvalues={"user_id": user_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
-            )
-
         await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
-            _mark_remote_user_device_list_as_unsubscribed_txn,
+            self.mark_remote_user_device_list_as_unsubscribed_txn,
+            user_id,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+    ) -> None:
+        self.db_pool.simple_delete_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
         )
 
     async def get_dehydrated_device(
@@ -1271,6 +1307,33 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
         return changes
 
+    async def get_device_list_changes_in_room(
+        self, room_id: str, min_stream_id: int
+    ) -> Collection[Tuple[str, str]]:
+        """Get all device list changes that happened in the room since the given
+        stream ID.
+
+        Returns:
+            Collection of user ID/device ID tuples of all devices that have
+            changed
+        """
+
+        sql = """
+            SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
+            WHERE room_id = ? AND stream_id > ?
+        """
+
+        def get_device_list_changes_in_room_txn(
+            txn: LoggingTransaction,
+        ) -> Collection[Tuple[str, str]]:
+            txn.execute(sql, (room_id, min_stream_id))
+            return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+        return await self.db_pool.runInteraction(
+            "get_device_list_changes_in_room",
+            get_device_list_changes_in_room_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -1910,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_id: str,
         room_id: str,
-        stream_id: int,
+        stream_id: Optional[int],
         hosts: Collection[str],
         context: Optional[Dict[str, str]],
     ) -> None:
         """Queue the device update to be sent to the given set of hosts,
         calculated from the room ID.
 
-        Marks the associated row in `device_lists_changes_in_room` as handled.
+        Marks the associated row in `device_lists_changes_in_room` as handled,
+        if `stream_id` is provided.
         """
 
         def add_device_list_outbound_pokes_txn(
@@ -1933,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     context=context,
                 )
 
-            self.db_pool.simple_update_txn(
-                txn,
-                table="device_lists_changes_in_room",
-                keyvalues={
-                    "user_id": user_id,
-                    "device_id": device_id,
-                    "stream_id": stream_id,
-                    "room_id": room_id,
-                },
-                updatevalues={"converted_to_destinations": True},
-            )
+            if stream_id:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    table="device_lists_changes_in_room",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "stream_id": stream_id,
+                        "room_id": room_id,
+                    },
+                    updatevalues={"converted_to_destinations": True},
+                )
 
         if not hosts:
             # If there are no hosts then we don't try and generate stream IDs.
@@ -1959,3 +2024,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 add_device_list_outbound_pokes_txn,
                 stream_ids,
             )
+
+    async def add_remote_device_list_to_pending(
+        self, user_id: str, device_id: str
+    ) -> None:
+        """Add a device list update to the table tracking remote device list
+        updates during partial joins.
+        """
+
+        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+            await self.db_pool.simple_upsert(
+                table="device_lists_remote_pending",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                },
+                values={"stream_id": stream_id},
+                desc="add_remote_device_list_to_pending",
+            )
+
+    async def get_pending_remote_device_list_updates_for_room(
+        self, room_id: str
+    ) -> Collection[Tuple[str, str]]:
+        """Get the set of remote device list updates from the pending table for
+        the room.
+        """
+
+        min_device_stream_id = await self.db_pool.simple_select_one_onecol(
+            table="partial_state_rooms",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="device_lists_stream_id",
+            desc="get_pending_remote_device_list_updates_for_room_device",
+        )
+
+        sql = """
+            SELECT user_id, device_id FROM device_lists_remote_pending AS d
+            INNER JOIN current_state_events AS c ON
+                type = 'm.room.member'
+                AND state_key = user_id
+                AND membership = 'join'
+            WHERE
+                room_id = ? AND stream_id > ?
+        """
+
+        def get_pending_remote_device_list_updates_for_room_txn(
+            txn: LoggingTransaction,
+        ) -> Collection[Tuple[str, str]]:
+            txn.execute(sql, (room_id, min_device_stream_id))
+            return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+        return await self.db_pool.runInteraction(
+            "get_pending_remote_device_list_updates_for_room",
+            get_pending_remote_device_list_updates_for_room_txn,
+        )