summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-09-27 13:01:08 +0100
committerGitHub <noreply@github.com>2022-09-27 13:01:08 +0100
commite8318a433356413648bd180dcfc69c29ca319fc6 (patch)
tree340aabcb04f4fa3663a5c3b189131b7cee37b62e /synapse/storage/databases/main/devices.py
parentFaster room joins: Fix spurious error when joining a room (#13872) (diff)
downloadsynapse-e8318a433356413648bd180dcfc69c29ca319fc6.tar.xz
Handle the case of remote users leaving a partial join room for device lists (#13885)
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py64
1 files changed, 50 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 5d700ca6c3..1151fb0cc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
@@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="mark_remote_user_device_cache_as_valid",
         )
 
+    async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        await self.db_pool.runInteraction(
+            "_handle_potentially_left_users",
+            self.handle_potentially_left_users_txn,
+            user_ids,
+        )
+
+    def handle_potentially_left_users_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Set[str],
+    ) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+        left_users = user_ids - joined_users
+
+        for user_id in left_users:
+            self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(
-            txn: LoggingTransaction,
-        ) -> None:
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="device_lists_remote_extremeties",
-                keyvalues={"user_id": user_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
-            )
-
         await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
-            _mark_remote_user_device_list_as_unsubscribed_txn,
+            self.mark_remote_user_device_list_as_unsubscribed_txn,
+            user_id,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+    ) -> None:
+        self.db_pool.simple_delete_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
         )
 
     async def get_dehydrated_device(