summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <eric.eastwood@beta.gouv.fr>2024-08-07 20:47:13 -0500
committerEric Eastwood <eric.eastwood@beta.gouv.fr>2024-08-07 20:47:13 -0500
commit5cf3ad3d7fde5a2a8b8f949ef5a82b68c30878dc (patch)
tree2393d5199d8965c640cee2f00ddb6b4809434499
parentHandle to_delete (diff)
downloadsynapse-5cf3ad3d7fde5a2a8b8f949ef5a82b68c30878dc.tar.xz
Handle server left room
-rw-r--r--synapse/storage/databases/main/events.py295
-rw-r--r--tests/storage/test_events.py52
2 files changed, 199 insertions, 148 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ac77492e18..843dc22752 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -66,7 +66,13 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    MutableStateMap,
+    StateMap,
+    StrCollection,
+    get_domain_from_id,
+)
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -1178,6 +1184,168 @@ class PersistEventsStore:
             if ev_type == EventTypes.Member
         }
 
+        # We now update `sliding_sync_non_join_memberships`.
+        #
+        # This would only happen if someone was state reset out of the room
+        if to_delete:
+            txn.execute_batch(
+                "DELETE FROM sliding_sync_non_join_memberships"
+                " WHERE room_id = ? AND user_id = ?",
+                (
+                    (room_id, state_key)
+                    for event_type, state_key in to_delete
+                    if event_type == EventTypes.Member and self.is_mine_id(state_key)
+                ),
+            )
+
+        # We handle `sliding_sync_non_join_memberships` before `current_state_events` so
+        # we can gather the current state before it might be deleted if we are
+        # `no_longer_in_room`.
+        #
+        # We do this regardless of whether the server is `no_longer_in_room` or not
+        # because we still want a row if a local user was just left/kicked or got banned
+        # from the room.
+        if to_insert:
+            membership_event_id_to_user_id_map: Dict[str, str] = {}
+            for state_key, event_id in to_insert.items():
+                if state_key[0] == EventTypes.Member and self.is_mine_id(state_key[1]):
+                    membership_event_id_to_user_id_map[event_id] = state_key[1]
+
+            if len(membership_event_id_to_user_id_map) > 0:
+                # Map of values to insert/update in the `sliding_sync_non_join_memberships` table
+                sliding_sync_non_joined_rooms_insert_map: Dict[
+                    str, Optional[Union[str, bool]]
+                ] = {}
+
+                relevant_state_set = {
+                    (EventTypes.Create, ""),
+                    (EventTypes.RoomEncryption, ""),
+                    (EventTypes.Name, ""),
+                }
+
+                # Fetch the current state event IDs from the database
+                (
+                    event_type_and_state_key_in_list_clause,
+                    event_type_and_state_key_args,
+                ) = make_tuple_in_list_sql_clause(
+                    self.database_engine,
+                    ("type", "state_key"),
+                    relevant_state_set,
+                )
+                txn.execute(
+                    f"""
+                    SELECT c.event_id, c.type, c.state_key
+                    FROM current_state_events AS c
+                    WHERE
+                        c.room_id = ?
+                        AND {event_type_and_state_key_in_list_clause}
+                    """,
+                    [room_id] + event_type_and_state_key_args,
+                )
+                current_state_map: MutableStateMap[str] = {
+                    (event_type, state_key): event_id
+                    for event_id, event_type, state_key in txn
+                }
+                # Since we fetched the current state before we took `to_insert`/`to_delete`
+                # into account, we need to do a couple fixups.
+                #
+                # Update the current_state_map with what we have `to_delete`
+                for state_key in to_delete:
+                    current_state_map.pop(state_key, None)
+                # Update the current_state_map with what we have `to_insert`
+                for state_key, event_id in to_insert.items():
+                    if state_key in relevant_state_set:
+                        current_state_map[state_key] = event_id
+
+                # Fetch the raw event JSON from the database
+                (
+                    event_id_in_list_clause,
+                    event_id_args,
+                ) = make_in_list_sql_clause(
+                    self.database_engine,
+                    "event_id",
+                    current_state_map.values(),
+                )
+                txn.execute(
+                    f"""
+                    SELECT event_id, type, state_key, json FROM event_json
+                    INNER JOIN events USING (event_id)
+                    WHERE {event_id_in_list_clause}
+                    """,
+                    event_id_args,
+                )
+
+                # Parse the raw event JSON
+                for row in txn:
+                    event_id, event_type, state_key, json = row
+                    event_json = db_to_json(json)
+
+                    if event_type == EventTypes.Create:
+                        room_type = event_json.get("content", {}).get(
+                            EventContentFields.ROOM_TYPE
+                        )
+                        sliding_sync_non_joined_rooms_insert_map["room_type"] = (
+                            room_type
+                        )
+                    elif event_type == EventTypes.RoomEncryption:
+                        encryption_algorithm = event_json.get("content", {}).get(
+                            EventContentFields.ENCRYPTION_ALGORITHM
+                        )
+                        is_encrypted = encryption_algorithm is not None
+                        sliding_sync_non_joined_rooms_insert_map["is_encrypted"] = (
+                            is_encrypted
+                        )
+                    elif event_type == EventTypes.Name:
+                        room_name = event_json.get("content", {}).get(
+                            EventContentFields.ROOM_NAME
+                        )
+                        sliding_sync_non_joined_rooms_insert_map["room_name"] = (
+                            room_name
+                        )
+                    else:
+                        raise AssertionError(
+                            f"Unexpected event (we should not be fetching extra events): ({event_type}, {state_key})"
+                        )
+
+                # Update the `sliding_sync_non_join_memberships` table
+                insert_keys = sliding_sync_non_joined_rooms_insert_map.keys()
+                insert_values = sliding_sync_non_joined_rooms_insert_map.values()
+                # TODO: Only do this for non-join membership
+                txn.execute_batch(
+                    f"""
+                    WITH data_table (room_id, user_id, membership_event_id, membership, event_stream_ordering, {", ".join(insert_keys)}) AS (
+                        VALUES (
+                            ?, ?, ?,
+                            (SELECT membership FROM room_memberships WHERE event_id = ?),
+                            (SELECT stream_ordering FROM events WHERE event_id = ?),
+                            {", ".join("?" for _ in insert_values)}
+                        )
+                    )
+                    INSERT INTO sliding_sync_non_join_memberships
+                        (room_id, user_id, membership_event_id, membership, event_stream_ordering, {", ".join(insert_keys)})
+                    SELECT * FROM data_table
+                    WHERE membership != ?
+                    ON CONFLICT (room_id, user_id)
+                    DO UPDATE SET
+                        membership_event_id = EXCLUDED.membership_event_id,
+                        membership = EXCLUDED.membership,
+                        event_stream_ordering = EXCLUDED.event_stream_ordering,
+                        {", ".join(f"{key} = EXCLUDED.{key}" for key in insert_keys)}
+                    """,
+                    [
+                        [
+                            room_id,
+                            user_id,
+                            membership_event_id,
+                            membership_event_id,
+                            membership_event_id,
+                        ]
+                        + list(insert_values)
+                        + [Membership.JOIN]
+                        for membership_event_id, user_id in membership_event_id_to_user_id_map.items()
+                    ],
+                )
+
         if delta_state.no_longer_in_room:
             # Server is no longer in the room so we delete the room from
             # current_state_events, being careful we've already updated the
@@ -1403,131 +1571,6 @@ class PersistEventsStore:
                 ],
             )
 
-        # We now update `sliding_sync_non_join_memberships`. We do this regardless of
-        # whether the server is still in the room or not because we still want a row if
-        # a local user was just left/kicked or got banned from the room.
-        if to_delete:
-            txn.execute_batch(
-                "DELETE FROM sliding_sync_non_join_memberships"
-                " WHERE room_id = ? AND user_id = ?",
-                (
-                    (room_id, state_key)
-                    for event_type, state_key in to_delete
-                    if event_type == EventTypes.Member and self.is_mine_id(state_key)
-                ),
-            )
-
-        if to_insert:
-            membership_event_id_to_user_id_map: Dict[str, str] = {}
-            for state_key, event_id in to_insert.items():
-                if state_key[0] == EventTypes.Member and self.is_mine_id(state_key[1]):
-                    membership_event_id_to_user_id_map[event_id] = state_key[1]
-
-            if len(membership_event_id_to_user_id_map) > 0:
-                # Map of values to insert/update in the `sliding_sync_non_join_memberships` table
-                sliding_sync_non_joined_rooms_insert_map: Dict[
-                    str, Optional[Union[str, bool]]
-                ] = {}
-
-                # Fetch the events from the database
-                #
-                # TODO: We should gather this data before we delete the
-                # `current_state_events` in a `no_longer_in_room` situation.
-                (
-                    event_type_and_state_key_in_list_clause,
-                    event_type_and_state_key_args,
-                ) = make_tuple_in_list_sql_clause(
-                    self.database_engine,
-                    ("type", "state_key"),
-                    [
-                        (EventTypes.Create, ""),
-                        (EventTypes.RoomEncryption, ""),
-                        (EventTypes.Name, ""),
-                    ],
-                )
-                txn.execute(
-                    f"""
-                    SELECT c.event_id, c.type, c.state_key, j.json
-                    FROM current_state_events AS c
-                    INNER JOIN event_json AS j USING (event_id)
-                    WHERE
-                        c.room_id = ?
-                        AND {event_type_and_state_key_in_list_clause}
-                    """,
-                    [room_id] + event_type_and_state_key_args,
-                )
-
-                # Parse the raw event JSON
-                for row in txn:
-                    event_id, event_type, state_key, json = row
-                    event_json = db_to_json(json)
-
-                    if event_type == EventTypes.Create:
-                        room_type = event_json.get("content", {}).get(
-                            EventContentFields.ROOM_TYPE
-                        )
-                        sliding_sync_non_joined_rooms_insert_map["room_type"] = (
-                            room_type
-                        )
-                    elif event_type == EventTypes.RoomEncryption:
-                        encryption_algorithm = event_json.get("content", {}).get(
-                            EventContentFields.ENCRYPTION_ALGORITHM
-                        )
-                        is_encrypted = encryption_algorithm is not None
-                        sliding_sync_non_joined_rooms_insert_map["is_encrypted"] = (
-                            is_encrypted
-                        )
-                    elif event_type == EventTypes.Name:
-                        room_name = event_json.get("content", {}).get(
-                            EventContentFields.ROOM_NAME
-                        )
-                        sliding_sync_non_joined_rooms_insert_map["room_name"] = (
-                            room_name
-                        )
-                    else:
-                        raise AssertionError(
-                            f"Unexpected event (we should not be fetching extra events): ({event_type}, {state_key})"
-                        )
-
-                # Update the `sliding_sync_non_join_memberships` table
-                insert_keys = sliding_sync_non_joined_rooms_insert_map.keys()
-                insert_values = sliding_sync_non_joined_rooms_insert_map.values()
-                # TODO: Only do this for non-join membership
-                txn.execute_batch(
-                    f"""
-                    WITH data_table (room_id, user_id, membership_event_id, membership, event_stream_ordering, {", ".join(insert_keys)}) AS (
-                        VALUES (
-                            ?, ?, ?,
-                            (SELECT membership FROM room_memberships WHERE event_id = ?),
-                            (SELECT stream_ordering FROM events WHERE event_id = ?),
-                            {", ".join("?" for _ in insert_values)}
-                        )
-                    )
-                    INSERT INTO sliding_sync_non_join_memberships
-                        (room_id, user_id, membership_event_id, membership, event_stream_ordering, {", ".join(insert_keys)})
-                    SELECT * FROM data_table
-                    WHERE membership != ?
-                    ON CONFLICT (room_id, user_id)
-                    DO UPDATE SET
-                        membership_event_id = EXCLUDED.membership_event_id,
-                        membership = EXCLUDED.membership,
-                        event_stream_ordering = EXCLUDED.event_stream_ordering,
-                        {", ".join(f"{key} = EXCLUDED.{key}" for key in insert_keys)}
-                    """,
-                    [
-                        [
-                            room_id,
-                            user_id,
-                            membership_event_id,
-                            membership_event_id,
-                            membership_event_id,
-                        ]
-                        + list(insert_values)
-                        + [Membership.JOIN]
-                        for membership_event_id, user_id in membership_event_id_to_user_id_map.items()
-                    ],
-                )
-
         txn.call_after(
             self.store._curr_state_delta_stream_cache.entity_has_changed,
             room_id,
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ea388458b6..2fb863a8c9 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -863,7 +863,7 @@ class SlidingSyncPrePopulatedTablesTestCase(HomeserverTestCase):
         Test users who was invited shows up in `sliding_sync_non_join_memberships`.
         """
         user1_id = self.register_user("user1", "pass")
-        user1_tok = self.login(user1_id, "pass")
+        _user1_tok = self.login(user1_id, "pass")
         user2_id = self.register_user("user2", "pass")
         user2_tok = self.login(user2_id, "pass")
 
@@ -958,7 +958,7 @@ class SlidingSyncPrePopulatedTablesTestCase(HomeserverTestCase):
         `sliding_sync_non_join_memberships`.
         """
         user1_id = self.register_user("user1", "pass")
-        user1_tok = self.login(user1_id, "pass")
+        _user1_tok = self.login(user1_id, "pass")
         user2_id = self.register_user("user2", "pass")
         user2_tok = self.login(user2_id, "pass")
         user3_id = self.register_user("user3", "pass")
@@ -1088,28 +1088,36 @@ class SlidingSyncPrePopulatedTablesTestCase(HomeserverTestCase):
         self.assertIncludes(
             set(sliding_sync_non_join_memberships_results.keys()),
             {
-                _SlidingSyncNonJoinMembershipResult(
-                    room_id=room_id1,
-                    user_id=user1_id,
-                    membership_event_id=user1_leave_response["event_id"],
-                    membership=Membership.LEAVE,
-                    event_stream_ordering=user1_leave_event_pos.stream,
-                    room_type=None,
-                    room_name=None,
-                    is_encrypted=False,
-                ),
-                _SlidingSyncNonJoinMembershipResult(
-                    room_id=room_id1,
-                    user_id=user2_id,
-                    membership_event_id=user2_leave_response["event_id"],
-                    membership=Membership.LEAVE,
-                    event_stream_ordering=user2_leave_event_pos.stream,
-                    room_type=None,
-                    room_name=None,
-                    is_encrypted=False,
-                ),
+                (room_id1, user1_id),
+                (room_id1, user2_id),
             },
             exact=True,
         )
+        self.assertEqual(
+            sliding_sync_non_join_memberships_results.get((room_id1, user1_id)),
+            _SlidingSyncNonJoinMembershipResult(
+                room_id=room_id1,
+                user_id=user1_id,
+                membership_event_id=user1_leave_response["event_id"],
+                membership=Membership.LEAVE,
+                event_stream_ordering=user1_leave_event_pos.stream,
+                room_type=None,
+                room_name=None,
+                is_encrypted=False,
+            ),
+        )
+        self.assertEqual(
+            sliding_sync_non_join_memberships_results.get((room_id1, user2_id)),
+            _SlidingSyncNonJoinMembershipResult(
+                room_id=room_id1,
+                user_id=user2_id,
+                membership_event_id=user2_leave_response["event_id"],
+                membership=Membership.LEAVE,
+                event_stream_ordering=user2_leave_event_pos.stream,
+                room_type=None,
+                room_name=None,
+                is_encrypted=False,
+            ),
+        )
 
     # TODO: test_non_join_state_reset