summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
author: Erik Johnston <erikj@element.io> 2024-06-24 15:40:28 +0100
committer: GitHub <noreply@github.com> 2024-06-24 14:40:28 +0000
commit: 930a64b6c1a4fe096d541bf9c5f0279fb636ed16 (patch)
tree: 083aad3888cdf294f68103b95d8070f31e1cffbf /synapse
parent: Add support for MSC3823 - Account Suspension Part 2 (#17255) (diff)
download: synapse-930a64b6c1a4fe096d541bf9c5f0279fb636ed16.tar.xz
Reintroduce #17291. (#17338)
This is #17291 (which got reverted), with some added fixups, and change
so that tests actually pick up the error.

The problem was that we were not calculating any new chain IDs due to a
missing `not` in a condition.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/controllers/persist_events.py12
-rw-r--r--synapse/storage/databases/main/event_federation.py20
-rw-r--r--synapse/storage/databases/main/events.py251
3 files changed, 206 insertions, 77 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 84699a2ee1..d0e015bf19 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
                         room_id, chunk
                     )
 
+            with Measure(self._clock, "calculate_chain_cover_index_for_events"):
+                # We now calculate chain ID/sequence numbers for any state events we're
+                # persisting. We ignore out of band memberships as we're not in the room
+                # and won't have their auth chain (we'll fix it up later if we join the
+                # room).
+                #
+                # See: docs/auth_chain_difference_algorithm.md
+                new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
+                    room_id, [e for e, _ in chunk]
+                )
+
             await self.persist_events_store._persist_events_and_state_updates(
                 room_id,
                 chunk,
@@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
                 inhibit_local_membership_updates=backfilled,
+                new_event_links=new_event_links,
             )
 
         return replaced_events
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fb132ef090..24abab4a23 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -148,6 +148,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             500000, "_event_auth_cache", size_callback=len
         )
 
+        # Flag used by unit tests to disable fallback when there is no chain cover
+        # index.
+        self.tests_allow_no_chain_cover_index = True
+
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
         if isinstance(self.database_engine, PostgresEngine):
@@ -220,8 +224,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
-                # for the events in question, so we fall back to the old method.
-                pass
+                # for the events in question, so we fall back to the old method
+                # (except in tests)
+                if not self.tests_allow_no_chain_cover_index:
+                    raise
 
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
@@ -271,7 +277,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         if events_missing_chain_info:
             # This can happen due to e.g. downgrade/upgrade of the server. We
             # raise an exception and fall back to the previous algorithm.
-            logger.info(
+            logger.error(
                 "Unexpectedly found that events don't have chain IDs in room %s: %s",
                 room_id,
                 events_missing_chain_info,
@@ -482,8 +488,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
-                # for the events in question, so we fall back to the old method.
-                pass
+                # for the events in question, so we fall back to the old method
+                # (except in tests)
+                if not self.tests_allow_no_chain_cover_index:
+                    raise
 
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
@@ -710,7 +718,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         if events_missing_chain_info - event_to_auth_ids.keys():
             # Uh oh, we somehow haven't correctly done the chain cover index,
             # bail and fall back to the old method.
-            logger.info(
+            logger.error(
                 "Unexpectedly found that events don't have chain IDs in room %s: %s",
                 room_id,
                 events_missing_chain_info - event_to_auth_ids.keys(),
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 66428e6c8e..1f7acdb859 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,6 @@ from typing import (
     Optional,
     Set,
     Tuple,
-    Union,
     cast,
 )
 
@@ -100,6 +99,23 @@ class DeltaState:
         return not self.to_delete and not self.to_insert and not self.no_longer_in_room
 
 
+@attr.s(slots=True, auto_attribs=True)
+class NewEventChainLinks:
+    """Information about new auth chain links that need to be added to the DB.
+
+    Attributes:
+        chain_id, sequence_number: the IDs corresponding to the event being
+            inserted, and the starting point of the links
+        links: Lists the links that need to be added, 2-tuple of the chain
+            ID/sequence number of the end point of the link.
+    """
+
+    chain_id: int
+    sequence_number: int
+
+    links: List[Tuple[int, int]] = attr.Factory(list)
+
+
 class PersistEventsStore:
     """Contains all the functions for writing events to the database.
 
@@ -148,6 +164,7 @@ class PersistEventsStore:
         *,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
+        new_event_links: Dict[str, NewEventChainLinks],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
@@ -217,6 +234,7 @@ class PersistEventsStore:
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
+                new_event_links=new_event_links,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
@@ -243,6 +261,87 @@ class PersistEventsStore:
                     (room_id,), frozenset(new_forward_extremities)
                 )
 
+    async def calculate_chain_cover_index_for_events(
+        self, room_id: str, events: Collection[EventBase]
+    ) -> Dict[str, NewEventChainLinks]:
+        # Filter to state events, and ensure there are no duplicates.
+        state_events = []
+        seen_events = set()
+        for event in events:
+            if not event.is_state() or event.event_id in seen_events:
+                continue
+
+            state_events.append(event)
+            seen_events.add(event.event_id)
+
+        if not state_events:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "_calculate_chain_cover_index_for_events",
+            self.calculate_chain_cover_index_for_events_txn,
+            room_id,
+            state_events,
+        )
+
+    def calculate_chain_cover_index_for_events_txn(
+        self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
+    ) -> Dict[str, NewEventChainLinks]:
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        row = self.db_pool.simple_select_one_txn(
+            txn,
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("room_id", "has_auth_chain_index"),
+            allow_none=True,
+        )
+        if row is None or row[1] is False:
+            return {}
+
+        # Filter out events that we've already calculated.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chains",
+            column="event_id",
+            iterable=[e.event_id for e in state_events],
+            keyvalues={},
+            retcols=("event_id",),
+        )
+        already_persisted_events = {event_id for event_id, in rows}
+        state_events = [
+            event
+            for event in state_events
+            if event.event_id not in already_persisted_events
+        ]
+
+        if not state_events:
+            return {}
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
+        event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
+        event_to_room_id = {e.event_id: e.room_id for e in state_events}
+
+        return self._calculate_chain_cover_index(
+            txn,
+            self.db_pool,
+            self.store.event_chain_id_gen,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+        )
+
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
@@ -358,6 +457,7 @@ class PersistEventsStore:
         inhibit_local_membership_updates: bool,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
+        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -466,7 +566,9 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
-        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+        self._persist_event_auth_chain_txn(
+            txn, [e for e, _ in events_and_contexts], new_event_links
+        )
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -496,7 +598,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events: List[EventBase],
+        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
+        if new_event_links:
+            self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+
        # We only care about state events, so skip this if there are no state events.
         if not any(e.is_state() for e in events):
             return
@@ -519,62 +625,37 @@ class PersistEventsStore:
             ],
         )
 
-        # We now calculate chain ID/sequence numbers for any state events we're
-        # persisting. We ignore out of band memberships as we're not in the room
-        # and won't have their auth chain (we'll fix it up later if we join the
-        # room).
-        #
-        # See: docs/auth_chain_difference_algorithm.md
-
-        # We ignore legacy rooms that we aren't filling the chain cover index
-        # for.
-        rows = cast(
-            List[Tuple[str, Optional[Union[int, bool]]]],
-            self.db_pool.simple_select_many_txn(
-                txn,
-                table="rooms",
-                column="room_id",
-                iterable={event.room_id for event in events if event.is_state()},
-                keyvalues={},
-                retcols=("room_id", "has_auth_chain_index"),
-            ),
-        )
-        rooms_using_chain_index = {
-            room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
-        }
-
-        state_events = {
-            event.event_id: event
-            for event in events
-            if event.is_state() and event.room_id in rooms_using_chain_index
-        }
-
-        if not state_events:
-            return
+    @classmethod
+    def _add_chain_cover_index(
+        cls,
+        txn: LoggingTransaction,
+        db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, StrCollection],
+    ) -> None:
+        """Calculate and persist the chain cover index for the given events.
 
-        # We need to know the type/state_key and auth events of the events we're
-        # calculating chain IDs for. We don't rely on having the full Event
-        # instances as we'll potentially be pulling more events from the DB and
-        # we don't need the overhead of fetching/parsing the full event JSON.
-        event_to_types = {
-            e.event_id: (e.type, e.state_key) for e in state_events.values()
-        }
-        event_to_auth_chain = {
-            e.event_id: e.auth_event_ids() for e in state_events.values()
-        }
-        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
 
-        self._add_chain_cover_index(
+        new_event_links = cls._calculate_chain_cover_index(
             txn,
-            self.db_pool,
-            self.store.event_chain_id_gen,
+            db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
         )
+        cls._persist_chain_cover_index(txn, db_pool, new_event_links)
 
     @classmethod
-    def _add_chain_cover_index(
+    def _calculate_chain_cover_index(
         cls,
         txn: LoggingTransaction,
         db_pool: DatabasePool,
@@ -582,7 +663,7 @@ class PersistEventsStore:
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, StrCollection],
-    ) -> None:
+    ) -> Dict[str, NewEventChainLinks]:
         """Calculate the chain cover index for the given events.
 
         Args:
@@ -590,6 +671,10 @@ class PersistEventsStore:
             event_to_types: Event ID to type and state_key of the event
             event_to_auth_chain: Event ID to list of auth event IDs of the
                 event (events with no auth events can be excluded).
+
+        Returns:
+            A mapping with any new auth chain links we need to add, keyed by
+            event ID.
         """
 
         # Map from event ID to chain ID/sequence number.
@@ -708,11 +793,11 @@ class PersistEventsStore:
                     room_id = event_to_room_id.get(event_id)
                     if room_id:
                         e_type, state_key = event_to_types[event_id]
-                        db_pool.simple_insert_txn(
+                        db_pool.simple_upsert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
+                            keyvalues={"event_id": event_id},
                             values={
-                                "event_id": event_id,
                                 "room_id": room_id,
                                 "type": e_type,
                                 "state_key": state_key,
@@ -724,7 +809,7 @@ class PersistEventsStore:
                     break
 
         if not events_to_calc_chain_id_for:
-            return
+            return {}
 
         # Allocate chain ID/sequence numbers to each new event.
         new_chain_tuples = cls._allocate_chain_ids(
@@ -739,23 +824,10 @@ class PersistEventsStore:
         )
         chain_map.update(new_chain_tuples)
 
-        db_pool.simple_insert_many_txn(
-            txn,
-            table="event_auth_chains",
-            keys=("event_id", "chain_id", "sequence_number"),
-            values=[
-                (event_id, c_id, seq)
-                for event_id, (c_id, seq) in new_chain_tuples.items()
-            ],
-        )
-
-        db_pool.simple_delete_many_txn(
-            txn,
-            table="event_auth_chain_to_calculate",
-            keyvalues={},
-            column="event_id",
-            values=new_chain_tuples,
-        )
+        to_return = {
+            event_id: NewEventChainLinks(chain_id, sequence_number)
+            for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
+        }
 
         # Now we need to calculate any new links between chains caused by
         # the new events.
@@ -825,10 +897,38 @@ class PersistEventsStore:
                 auth_chain_id, auth_sequence_number = chain_map[auth_id]
 
                 # Step 2a, add link between the event and auth event
+                to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
                 chain_links.add_link(
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )
 
+        return to_return
+
+    @classmethod
+    def _persist_chain_cover_index(
+        cls,
+        txn: LoggingTransaction,
+        db_pool: DatabasePool,
+        new_event_links: Dict[str, NewEventChainLinks],
+    ) -> None:
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            keys=("event_id", "chain_id", "sequence_number"),
+            values=[
+                (event_id, new_links.chain_id, new_links.sequence_number)
+                for event_id, new_links in new_event_links.items()
+            ],
+        )
+
+        db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            values=new_event_links,
+        )
+
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -838,7 +938,16 @@ class PersistEventsStore:
                 "target_chain_id",
                 "target_sequence_number",
             ),
-            values=list(chain_links.get_additions()),
+            values=[
+                (
+                    new_links.chain_id,
+                    new_links.sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                )
+                for new_links in new_event_links.values()
+                for (target_chain_id, target_sequence_number) in new_links.links
+            ],
         )
 
     @staticmethod