summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13885.misc1
-rw-r--r--synapse/app/admin_cmd.py2
-rw-r--r--synapse/storage/controllers/persist_events.py71
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/devices.py64
-rw-r--r--synapse/storage/databases/main/events.py6
-rw-r--r--synapse/storage/databases/main/roommember.py46
7 files changed, 85 insertions, 107 deletions
diff --git a/changelog.d/13885.misc b/changelog.d/13885.misc
new file mode 100644
index 0000000000..bc76b862df
--- /dev/null
+++ b/changelog.d/13885.misc
@@ -0,0 +1 @@
+Correctly handle a race with device lists when a remote user leaves during a partial join.
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 8a583d3ec6..3c8c00ea5b 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -53,9 +53,9 @@ logger = logging.getLogger("synapse.app.admin_cmd")
 
 class AdminCmdSlavedStore(
     SlavedFilteringStore,
-    SlavedDeviceStore,
     SlavedPushRuleStore,
     SlavedEventStore,
+    SlavedDeviceStore,
     TagsWorkerStore,
     DeviceInboxWorkerStore,
     AccountDataWorkerStore,
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 501dbbc990..709cb792ed 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -598,11 +598,6 @@ class EventsPersistenceStorageController:
             # room
             state_delta_for_room: Dict[str, DeltaState] = {}
 
-            # Set of remote users which were in rooms the server has left or who may
-            # have left rooms the server is in. We should check if we still share any
-            # rooms and if not we mark their device lists as stale.
-            potentially_left_users: Set[str] = set()
-
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
                     # Work out the new "current state" for each room.
@@ -716,8 +711,6 @@ class EventsPersistenceStorageController:
                                 room_id,
                                 ev_ctx_rm,
                                 delta,
-                                current_state,
-                                potentially_left_users,
                             )
                             if not is_still_joined:
                                 logger.info("Server no longer in room %s", room_id)
@@ -725,20 +718,6 @@ class EventsPersistenceStorageController:
                                 current_state = {}
                                 delta.no_longer_in_room = True
 
-                            # Add all remote users that might have left rooms.
-                            potentially_left_users.update(
-                                user_id
-                                for event_type, user_id in delta.to_delete
-                                if event_type == EventTypes.Member
-                                and not self.is_mine_id(user_id)
-                            )
-                            potentially_left_users.update(
-                                user_id
-                                for event_type, user_id in delta.to_insert.keys()
-                                if event_type == EventTypes.Member
-                                and not self.is_mine_id(user_id)
-                            )
-
                             state_delta_for_room[room_id] = delta
 
             await self.persist_events_store._persist_events_and_state_updates(
@@ -749,8 +728,6 @@ class EventsPersistenceStorageController:
                 inhibit_local_membership_updates=backfilled,
             )
 
-            await self._handle_potentially_left_users(potentially_left_users)
-
         return replaced_events
 
     async def _calculate_new_extremities(
@@ -1126,8 +1103,6 @@ class EventsPersistenceStorageController:
         room_id: str,
         ev_ctx_rm: List[Tuple[EventBase, EventContext]],
         delta: DeltaState,
-        current_state: Optional[StateMap[str]],
-        potentially_left_users: Set[str],
     ) -> bool:
         """Check if the server will still be joined after the given events have
         been persised.
@@ -1137,11 +1112,6 @@ class EventsPersistenceStorageController:
             ev_ctx_rm
             delta: The delta of current state between what is in the database
                 and what the new current state will be.
-            current_state: The new current state if it already been calculated,
-                otherwise None.
-            potentially_left_users: If the server has left the room, then joined
-                remote users will be added to this set to indicate that the
-                server may no longer be sharing a room with them.
         """
 
         if not any(
@@ -1195,45 +1165,4 @@ class EventsPersistenceStorageController:
         ):
             return True
 
-        # The server will leave the room, so we go and find out which remote
-        # users will still be joined when we leave.
-        if current_state is None:
-            current_state = await self.main_store.get_partial_current_state_ids(room_id)
-            current_state = dict(current_state)
-            for key in delta.to_delete:
-                current_state.pop(key, None)
-
-            current_state.update(delta.to_insert)
-
-        remote_event_ids = [
-            event_id
-            for (
-                typ,
-                state_key,
-            ), event_id in current_state.items()
-            if typ == EventTypes.Member and not self.is_mine_id(state_key)
-        ]
-        members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
-        potentially_left_users.update(
-            member.user_id
-            for member in members.values()
-            if member and member.membership == Membership.JOIN
-        )
-
         return False
-
-    async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
-        """Given a set of remote users check if the server still shares a room with
-        them. If not then mark those users' device cache as stale.
-        """
-
-        if not user_ids:
-            return
-
-        joined_users = await self.main_store.get_users_server_still_shares_room_with(
-            user_ids
-        )
-        left_users = user_ids - joined_users
-
-        for user_id in left_users:
-            await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 4dccbb732a..0843f10340 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -83,6 +83,7 @@ logger = logging.getLogger(__name__)
 
 class DataStore(
     EventsBackgroundUpdatesStore,
+    DeviceStore,
     RoomMemberStore,
     RoomStore,
     RoomBatchStore,
@@ -114,7 +115,6 @@ class DataStore(
     StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,
-    DeviceStore,
     DeviceInboxStore,
     UserDirectoryStore,
     UserErasureStore,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 5d700ca6c3..1151fb0cc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
@@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="mark_remote_user_device_cache_as_valid",
         )
 
+    async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        await self.db_pool.runInteraction(
+            "_handle_potentially_left_users",
+            self.handle_potentially_left_users_txn,
+            user_ids,
+        )
+
+    def handle_potentially_left_users_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Set[str],
+    ) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+        left_users = user_ids - joined_users
+
+        for user_id in left_users:
+            self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(
-            txn: LoggingTransaction,
-        ) -> None:
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="device_lists_remote_extremeties",
-                keyvalues={"user_id": user_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
-            )
-
         await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
-            _mark_remote_user_device_list_as_unsubscribed_txn,
+            self.mark_remote_user_device_list_as_unsubscribed_txn,
+            user_id,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+    ) -> None:
+        self.db_pool.simple_delete_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
         )
 
     async def get_dehydrated_device(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2e156a4a11..b59eb7478b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1202,6 +1202,12 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
+            # Check if any of the remote membership changes requires us to
+            # unsubscribe from their device lists.
+            self.store.handle_potentially_left_users_txn(
+                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+            )
+
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a8d224602a..8ada3cdac3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -662,31 +662,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(
-            txn: LoggingTransaction,
-        ) -> Set[str]:
-            sql = """
-                SELECT state_key FROM current_state_events
-                WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND %s
-                GROUP BY state_key
-            """
-
-            clause, args = make_in_list_sql_clause(
-                self.database_engine, "state_key", user_ids
-            )
+        return await self.db_pool.runInteraction(
+            "get_users_server_still_shares_room_with",
+            self.get_users_server_still_shares_room_with_txn,
+            user_ids,
+        )
 
-            txn.execute(sql % (clause,), args)
+    def get_users_server_still_shares_room_with_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Collection[str],
+    ) -> Set[str]:
+        if not user_ids:
+            return set()
 
-            return {row[0] for row in txn}
+        sql = """
+            SELECT state_key FROM current_state_events
+            WHERE
+                type = 'm.room.member'
+                AND membership = 'join'
+                AND %s
+            GROUP BY state_key
+        """
 
-        return await self.db_pool.runInteraction(
-            "get_users_server_still_shares_room_with",
-            _get_users_server_still_shares_room_with_txn,
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "state_key", user_ids
         )
 
+        txn.execute(sql % (clause,), args)
+
+        return {row[0] for row in txn}
+
     @cancellable
     async def get_rooms_for_user(
         self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None