summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17291.misc1
-rw-r--r--synapse/storage/controllers/persist_events.py12
-rw-r--r--synapse/storage/databases/main/events.py249
-rw-r--r--tests/storage/test_event_chain.py9
-rw-r--r--tests/storage/test_event_federation.py41
5 files changed, 82 insertions, 230 deletions
diff --git a/changelog.d/17291.misc b/changelog.d/17291.misc
deleted file mode 100644
index b1f89a324d..0000000000
--- a/changelog.d/17291.misc
+++ /dev/null
@@ -1 +0,0 @@
-Do not block event sending/receiving while calulating large event auth chains.
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index d0e015bf19..84699a2ee1 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -617,17 +617,6 @@ class EventsPersistenceStorageController:
                         room_id, chunk
                     )
 
-            with Measure(self._clock, "calculate_chain_cover_index_for_events"):
-                # We now calculate chain ID/sequence numbers for any state events we're
-                # persisting. We ignore out of band memberships as we're not in the room
-                # and won't have their auth chain (we'll fix it up later if we join the
-                # room).
-                #
-                # See: docs/auth_chain_difference_algorithm.md
-                new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
-                    room_id, [e for e, _ in chunk]
-                )
-
             await self.persist_events_store._persist_events_and_state_updates(
                 room_id,
                 chunk,
@@ -635,7 +624,6 @@ class EventsPersistenceStorageController:
                 new_forward_extremities=new_forward_extremities,
                 use_negative_stream_ordering=backfilled,
                 inhibit_local_membership_updates=backfilled,
-                new_event_links=new_event_links,
             )
 
         return replaced_events
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index c6df13c064..66428e6c8e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,6 +34,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    Union,
     cast,
 )
 
@@ -99,23 +100,6 @@ class DeltaState:
         return not self.to_delete and not self.to_insert and not self.no_longer_in_room
 
 
-@attr.s(slots=True, auto_attribs=True)
-class NewEventChainLinks:
-    """Information about new auth chain links that need to be added to the DB.
-
-    Attributes:
-        chain_id, sequence_number: the IDs corresponding to the event being
-            inserted, and the starting point of the links
-        links: Lists the links that need to be added, 2-tuple of the chain
-            ID/sequence number of the end point of the link.
-    """
-
-    chain_id: int
-    sequence_number: int
-
-    links: List[Tuple[int, int]] = attr.Factory(list)
-
-
 class PersistEventsStore:
     """Contains all the functions for writing events to the database.
 
@@ -164,7 +148,6 @@ class PersistEventsStore:
         *,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
-        new_event_links: Dict[str, NewEventChainLinks],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
@@ -234,7 +217,6 @@ class PersistEventsStore:
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
-                new_event_links=new_event_links,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
@@ -261,87 +243,6 @@ class PersistEventsStore:
                     (room_id,), frozenset(new_forward_extremities)
                 )
 
-    async def calculate_chain_cover_index_for_events(
-        self, room_id: str, events: Collection[EventBase]
-    ) -> Dict[str, NewEventChainLinks]:
-        # Filter to state events, and ensure there are no duplicates.
-        state_events = []
-        seen_events = set()
-        for event in events:
-            if not event.is_state() or event.event_id in seen_events:
-                continue
-
-            state_events.append(event)
-            seen_events.add(event.event_id)
-
-        if not state_events:
-            return {}
-
-        return await self.db_pool.runInteraction(
-            "_calculate_chain_cover_index_for_events",
-            self.calculate_chain_cover_index_for_events_txn,
-            room_id,
-            state_events,
-        )
-
-    def calculate_chain_cover_index_for_events_txn(
-        self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
-    ) -> Dict[str, NewEventChainLinks]:
-        # We now calculate chain ID/sequence numbers for any state events we're
-        # persisting. We ignore out of band memberships as we're not in the room
-        # and won't have their auth chain (we'll fix it up later if we join the
-        # room).
-        #
-        # See: docs/auth_chain_difference_algorithm.md
-
-        # We ignore legacy rooms that we aren't filling the chain cover index
-        # for.
-        row = self.db_pool.simple_select_one_txn(
-            txn,
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            retcols=("room_id", "has_auth_chain_index"),
-            allow_none=True,
-        )
-        if row is None:
-            return {}
-
-        # Filter out already persisted events.
-        rows = self.db_pool.simple_select_many_txn(
-            txn,
-            table="events",
-            column="event_id",
-            iterable=[e.event_id for e in state_events],
-            keyvalues={},
-            retcols=("event_id",),
-        )
-        already_persisted_events = {event_id for event_id, in rows}
-        state_events = [
-            event
-            for event in state_events
-            if event.event_id in already_persisted_events
-        ]
-
-        if not state_events:
-            return {}
-
-        # We need to know the type/state_key and auth events of the events we're
-        # calculating chain IDs for. We don't rely on having the full Event
-        # instances as we'll potentially be pulling more events from the DB and
-        # we don't need the overhead of fetching/parsing the full event JSON.
-        event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
-        event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
-        event_to_room_id = {e.event_id: e.room_id for e in state_events}
-
-        return self._calculate_chain_cover_index(
-            txn,
-            self.db_pool,
-            self.store.event_chain_id_gen,
-            event_to_room_id,
-            event_to_types,
-            event_to_auth_chain,
-        )
-
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
@@ -457,7 +358,6 @@ class PersistEventsStore:
         inhibit_local_membership_updates: bool,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
-        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -566,9 +466,7 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
-        self._persist_event_auth_chain_txn(
-            txn, [e for e, _ in events_and_contexts], new_event_links
-        )
+        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -598,7 +496,6 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events: List[EventBase],
-        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
         # We only care about state events, so this if there are no state events.
         if not any(e.is_state() for e in events):
@@ -622,40 +519,62 @@ class PersistEventsStore:
             ],
         )
 
-        if new_event_links:
-            self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
 
-    @classmethod
-    def _add_chain_cover_index(
-        cls,
-        txn: LoggingTransaction,
-        db_pool: DatabasePool,
-        event_chain_id_gen: SequenceGenerator,
-        event_to_room_id: Dict[str, str],
-        event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, StrCollection],
-    ) -> None:
-        """Calculate and persist the chain cover index for the given events.
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        rows = cast(
+            List[Tuple[str, Optional[Union[int, bool]]]],
+            self.db_pool.simple_select_many_txn(
+                txn,
+                table="rooms",
+                column="room_id",
+                iterable={event.room_id for event in events if event.is_state()},
+                keyvalues={},
+                retcols=("room_id", "has_auth_chain_index"),
+            ),
+        )
+        rooms_using_chain_index = {
+            room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
+        }
 
-        Args:
-            event_to_room_id: Event ID to the room ID of the event
-            event_to_types: Event ID to type and state_key of the event
-            event_to_auth_chain: Event ID to list of auth event IDs of the
-                event (events with no auth events can be excluded).
-        """
+        state_events = {
+            event.event_id: event
+            for event in events
+            if event.is_state() and event.room_id in rooms_using_chain_index
+        }
+
+        if not state_events:
+            return
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {
+            e.event_id: (e.type, e.state_key) for e in state_events.values()
+        }
+        event_to_auth_chain = {
+            e.event_id: e.auth_event_ids() for e in state_events.values()
+        }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
 
-        new_event_links = cls._calculate_chain_cover_index(
+        self._add_chain_cover_index(
             txn,
-            db_pool,
-            event_chain_id_gen,
+            self.db_pool,
+            self.store.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
         )
-        cls._persist_chain_cover_index(txn, db_pool, new_event_links)
 
     @classmethod
-    def _calculate_chain_cover_index(
+    def _add_chain_cover_index(
         cls,
         txn: LoggingTransaction,
         db_pool: DatabasePool,
@@ -663,7 +582,7 @@ class PersistEventsStore:
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, StrCollection],
-    ) -> Dict[str, NewEventChainLinks]:
+    ) -> None:
         """Calculate the chain cover index for the given events.
 
         Args:
@@ -671,10 +590,6 @@ class PersistEventsStore:
             event_to_types: Event ID to type and state_key of the event
             event_to_auth_chain: Event ID to list of auth event IDs of the
                 event (events with no auth events can be excluded).
-
-        Returns:
-            A mapping with any new auth chain links we need to add, keyed by
-            event ID.
         """
 
         # Map from event ID to chain ID/sequence number.
@@ -793,11 +708,11 @@ class PersistEventsStore:
                     room_id = event_to_room_id.get(event_id)
                     if room_id:
                         e_type, state_key = event_to_types[event_id]
-                        db_pool.simple_upsert_txn(
+                        db_pool.simple_insert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
-                            keyvalues={"event_id": event_id},
                             values={
+                                "event_id": event_id,
                                 "room_id": room_id,
                                 "type": e_type,
                                 "state_key": state_key,
@@ -809,7 +724,7 @@ class PersistEventsStore:
                     break
 
         if not events_to_calc_chain_id_for:
-            return {}
+            return
 
         # Allocate chain ID/sequence numbers to each new event.
         new_chain_tuples = cls._allocate_chain_ids(
@@ -824,10 +739,23 @@ class PersistEventsStore:
         )
         chain_map.update(new_chain_tuples)
 
-        to_return = {
-            event_id: NewEventChainLinks(chain_id, sequence_number)
-            for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
-        }
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            keys=("event_id", "chain_id", "sequence_number"),
+            values=[
+                (event_id, c_id, seq)
+                for event_id, (c_id, seq) in new_chain_tuples.items()
+            ],
+        )
+
+        db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            values=new_chain_tuples,
+        )
 
         # Now we need to calculate any new links between chains caused by
         # the new events.
@@ -897,38 +825,10 @@ class PersistEventsStore:
                 auth_chain_id, auth_sequence_number = chain_map[auth_id]
 
                 # Step 2a, add link between the event and auth event
-                to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
                 chain_links.add_link(
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )
 
-        return to_return
-
-    @classmethod
-    def _persist_chain_cover_index(
-        cls,
-        txn: LoggingTransaction,
-        db_pool: DatabasePool,
-        new_event_links: Dict[str, NewEventChainLinks],
-    ) -> None:
-        db_pool.simple_insert_many_txn(
-            txn,
-            table="event_auth_chains",
-            keys=("event_id", "chain_id", "sequence_number"),
-            values=[
-                (event_id, new_links.chain_id, new_links.sequence_number)
-                for event_id, new_links in new_event_links.items()
-            ],
-        )
-
-        db_pool.simple_delete_many_txn(
-            txn,
-            table="event_auth_chain_to_calculate",
-            keyvalues={},
-            column="event_id",
-            values=new_event_links,
-        )
-
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -938,16 +838,7 @@ class PersistEventsStore:
                 "target_chain_id",
                 "target_sequence_number",
             ),
-            values=[
-                (
-                    new_links.chain_id,
-                    new_links.sequence_number,
-                    target_chain_id,
-                    target_sequence_number,
-                )
-                for new_links in new_event_links.values()
-                for (target_chain_id, target_sequence_number) in new_links.links
-            ],
+            values=list(chain_links.get_additions()),
         )
 
     @staticmethod
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c4e216c308..81feb3ec29 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -447,14 +447,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
             )
 
             # Actually call the function that calculates the auth chain stuff.
-            new_event_links = (
-                persist_events_store.calculate_chain_cover_index_for_events_txn(
-                    txn, events[0].room_id, [e for e in events if e.is_state()]
-                )
-            )
-            persist_events_store._persist_event_auth_chain_txn(
-                txn, events, new_event_links
-            )
+            persist_events_store._persist_event_auth_chain_txn(txn, events)
 
         self.get_success(
             persist_events_store.db_pool.runInteraction(
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 1832a23714..0a6253e22c 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -365,19 +365,12 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                     },
                 )
 
-            events = [
-                cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
-                for event_id in AUTH_GRAPH
-            ]
-            new_event_links = (
-                self.persist_events.calculate_chain_cover_index_for_events_txn(
-                    txn, room_id, [e for e in events if e.is_state()]
-                )
-            )
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
-                events,
-                new_event_links,
+                [
+                    cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+                    for event_id in AUTH_GRAPH
+                ],
             )
 
         self.get_success(
@@ -635,20 +628,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 )
 
             # Insert all events apart from 'B'
-            events = [
-                cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
-                for event_id in auth_graph
-                if event_id != "b"
-            ]
-            new_event_links = (
-                self.persist_events.calculate_chain_cover_index_for_events_txn(
-                    txn, room_id, [e for e in events if e.is_state()]
-                )
-            )
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
-                events,
-                new_event_links,
+                [
+                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
+                    for event_id in auth_graph
+                    if event_id != "b"
+                ],
             )
 
             # Now we insert the event 'B' without a chain cover, by temporarily
@@ -661,14 +647,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 updatevalues={"has_auth_chain_index": False},
             )
 
-            events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
-            new_event_links = (
-                self.persist_events.calculate_chain_cover_index_for_events_txn(
-                    txn, room_id, [e for e in events if e.is_state()]
-                )
-            )
             self.persist_events._persist_event_auth_chain_txn(
-                txn, events, new_event_links
+                txn,
+                [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
             )
 
             self.store.db_pool.simple_update_txn(