summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/event_federation.py20
-rw-r--r--synapse/storage/databases/main/events.py251
2 files changed, 194 insertions, 77 deletions
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fb132ef090..24abab4a23 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -148,6 +148,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             500000, "_event_auth_cache", size_callback=len
         )
 
+        # Flag used by unit tests to disable fallback when there is no chain cover
+        # index.
+        self.tests_allow_no_chain_cover_index = True
+
         self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
 
         if isinstance(self.database_engine, PostgresEngine):
@@ -220,8 +224,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
-                # for the events in question, so we fall back to the old method.
-                pass
+                # for the events in question, so we fall back to the old method
+                # (except in tests)
+                if not self.tests_allow_no_chain_cover_index:
+                    raise
 
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
@@ -271,7 +277,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         if events_missing_chain_info:
             # This can happen due to e.g. downgrade/upgrade of the server. We
             # raise an exception and fall back to the previous algorithm.
-            logger.info(
+            logger.error(
                 "Unexpectedly found that events don't have chain IDs in room %s: %s",
                 room_id,
                 events_missing_chain_info,
@@ -482,8 +488,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
-                # for the events in question, so we fall back to the old method.
-                pass
+                # for the events in question, so we fall back to the old method
+                # (except in tests)
+                if not self.tests_allow_no_chain_cover_index:
+                    raise
 
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
@@ -710,7 +718,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         if events_missing_chain_info - event_to_auth_ids.keys():
             # Uh oh, we somehow haven't correctly done the chain cover index,
             # bail and fall back to the old method.
-            logger.info(
+            logger.error(
                 "Unexpectedly found that events don't have chain IDs in room %s: %s",
                 room_id,
                 events_missing_chain_info - event_to_auth_ids.keys(),
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 66428e6c8e..1f7acdb859 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,6 @@ from typing import (
     Optional,
     Set,
     Tuple,
-    Union,
     cast,
 )
 
@@ -100,6 +99,23 @@ class DeltaState:
         return not self.to_delete and not self.to_insert and not self.no_longer_in_room
 
 
+@attr.s(slots=True, auto_attribs=True)
+class NewEventChainLinks:
+    """Information about new auth chain links that need to be added to the DB.
+
+    Attributes:
+        chain_id, sequence_number: the IDs corresponding to the event being
+            inserted, and the starting point of the links
+        links: Lists the links that need to be added, 2-tuple of the chain
+            ID/sequence number of the end point of the link.
+    """
+
+    chain_id: int
+    sequence_number: int
+
+    links: List[Tuple[int, int]] = attr.Factory(list)
+
+
 class PersistEventsStore:
     """Contains all the functions for writing events to the database.
 
@@ -148,6 +164,7 @@ class PersistEventsStore:
         *,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
+        new_event_links: Dict[str, NewEventChainLinks],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
@@ -217,6 +234,7 @@ class PersistEventsStore:
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
+                new_event_links=new_event_links,
             )
             persist_event_counter.inc(len(events_and_contexts))
 
@@ -243,6 +261,87 @@ class PersistEventsStore:
                     (room_id,), frozenset(new_forward_extremities)
                 )
 
+    async def calculate_chain_cover_index_for_events(
+        self, room_id: str, events: Collection[EventBase]
+    ) -> Dict[str, NewEventChainLinks]:
+        # Filter to state events, and ensure there are no duplicates.
+        state_events = []
+        seen_events = set()
+        for event in events:
+            if not event.is_state() or event.event_id in seen_events:
+                continue
+
+            state_events.append(event)
+            seen_events.add(event.event_id)
+
+        if not state_events:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "_calculate_chain_cover_index_for_events",
+            self.calculate_chain_cover_index_for_events_txn,
+            room_id,
+            state_events,
+        )
+
+    def calculate_chain_cover_index_for_events_txn(
+        self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
+    ) -> Dict[str, NewEventChainLinks]:
+        # We now calculate chain ID/sequence numbers for any state events we're
+        # persisting. We ignore out of band memberships as we're not in the room
+        # and won't have their auth chain (we'll fix it up later if we join the
+        # room).
+        #
+        # See: docs/auth_chain_difference_algorithm.md
+
+        # We ignore legacy rooms that we aren't filling the chain cover index
+        # for.
+        row = self.db_pool.simple_select_one_txn(
+            txn,
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcols=("room_id", "has_auth_chain_index"),
+            allow_none=True,
+        )
+        if row is None or row[1] is False:
+            return {}
+
+        # Filter out events that we've already calculated.
+        rows = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth_chains",
+            column="event_id",
+            iterable=[e.event_id for e in state_events],
+            keyvalues={},
+            retcols=("event_id",),
+        )
+        already_persisted_events = {event_id for event_id, in rows}
+        state_events = [
+            event
+            for event in state_events
+            if event.event_id not in already_persisted_events
+        ]
+
+        if not state_events:
+            return {}
+
+        # We need to know the type/state_key and auth events of the events we're
+        # calculating chain IDs for. We don't rely on having the full Event
+        # instances as we'll potentially be pulling more events from the DB and
+        # we don't need the overhead of fetching/parsing the full event JSON.
+        event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
+        event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
+        event_to_room_id = {e.event_id: e.room_id for e in state_events}
+
+        return self._calculate_chain_cover_index(
+            txn,
+            self.db_pool,
+            self.store.event_chain_id_gen,
+            event_to_room_id,
+            event_to_types,
+            event_to_auth_chain,
+        )
+
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
@@ -358,6 +457,7 @@ class PersistEventsStore:
         inhibit_local_membership_updates: bool,
         state_delta_for_room: Optional[DeltaState],
         new_forward_extremities: Optional[Set[str]],
+        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -466,7 +566,9 @@ class PersistEventsStore:
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
-        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+        self._persist_event_auth_chain_txn(
+            txn, [e for e, _ in events_and_contexts], new_event_links
+        )
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -496,7 +598,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events: List[EventBase],
+        new_event_links: Dict[str, NewEventChainLinks],
     ) -> None:
+        if new_event_links:
+            self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+
         # We only care about state events, so this if there are no state events.
         if not any(e.is_state() for e in events):
             return
@@ -519,62 +625,37 @@ class PersistEventsStore:
             ],
         )
 
-        # We now calculate chain ID/sequence numbers for any state events we're
-        # persisting. We ignore out of band memberships as we're not in the room
-        # and won't have their auth chain (we'll fix it up later if we join the
-        # room).
-        #
-        # See: docs/auth_chain_difference_algorithm.md
-
-        # We ignore legacy rooms that we aren't filling the chain cover index
-        # for.
-        rows = cast(
-            List[Tuple[str, Optional[Union[int, bool]]]],
-            self.db_pool.simple_select_many_txn(
-                txn,
-                table="rooms",
-                column="room_id",
-                iterable={event.room_id for event in events if event.is_state()},
-                keyvalues={},
-                retcols=("room_id", "has_auth_chain_index"),
-            ),
-        )
-        rooms_using_chain_index = {
-            room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
-        }
-
-        state_events = {
-            event.event_id: event
-            for event in events
-            if event.is_state() and event.room_id in rooms_using_chain_index
-        }
-
-        if not state_events:
-            return
+    @classmethod
+    def _add_chain_cover_index(
+        cls,
+        txn: LoggingTransaction,
+        db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, StrCollection],
+    ) -> None:
+        """Calculate and persist the chain cover index for the given events.
 
-        # We need to know the type/state_key and auth events of the events we're
-        # calculating chain IDs for. We don't rely on having the full Event
-        # instances as we'll potentially be pulling more events from the DB and
-        # we don't need the overhead of fetching/parsing the full event JSON.
-        event_to_types = {
-            e.event_id: (e.type, e.state_key) for e in state_events.values()
-        }
-        event_to_auth_chain = {
-            e.event_id: e.auth_event_ids() for e in state_events.values()
-        }
-        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
 
-        self._add_chain_cover_index(
+        new_event_links = cls._calculate_chain_cover_index(
             txn,
-            self.db_pool,
-            self.store.event_chain_id_gen,
+            db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
         )
+        cls._persist_chain_cover_index(txn, db_pool, new_event_links)
 
     @classmethod
-    def _add_chain_cover_index(
+    def _calculate_chain_cover_index(
         cls,
         txn: LoggingTransaction,
         db_pool: DatabasePool,
@@ -582,7 +663,7 @@ class PersistEventsStore:
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, StrCollection],
-    ) -> None:
+    ) -> Dict[str, NewEventChainLinks]:
         """Calculate the chain cover index for the given events.
 
         Args:
@@ -590,6 +671,10 @@ class PersistEventsStore:
             event_to_types: Event ID to type and state_key of the event
             event_to_auth_chain: Event ID to list of auth event IDs of the
                 event (events with no auth events can be excluded).
+
+        Returns:
+            A mapping with any new auth chain links we need to add, keyed by
+            event ID.
         """
 
         # Map from event ID to chain ID/sequence number.
@@ -708,11 +793,11 @@ class PersistEventsStore:
                     room_id = event_to_room_id.get(event_id)
                     if room_id:
                         e_type, state_key = event_to_types[event_id]
-                        db_pool.simple_insert_txn(
+                        db_pool.simple_upsert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
+                            keyvalues={"event_id": event_id},
                             values={
-                                "event_id": event_id,
                                 "room_id": room_id,
                                 "type": e_type,
                                 "state_key": state_key,
@@ -724,7 +809,7 @@ class PersistEventsStore:
                     break
 
         if not events_to_calc_chain_id_for:
-            return
+            return {}
 
         # Allocate chain ID/sequence numbers to each new event.
         new_chain_tuples = cls._allocate_chain_ids(
@@ -739,23 +824,10 @@ class PersistEventsStore:
         )
         chain_map.update(new_chain_tuples)
 
-        db_pool.simple_insert_many_txn(
-            txn,
-            table="event_auth_chains",
-            keys=("event_id", "chain_id", "sequence_number"),
-            values=[
-                (event_id, c_id, seq)
-                for event_id, (c_id, seq) in new_chain_tuples.items()
-            ],
-        )
-
-        db_pool.simple_delete_many_txn(
-            txn,
-            table="event_auth_chain_to_calculate",
-            keyvalues={},
-            column="event_id",
-            values=new_chain_tuples,
-        )
+        to_return = {
+            event_id: NewEventChainLinks(chain_id, sequence_number)
+            for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
+        }
 
         # Now we need to calculate any new links between chains caused by
         # the new events.
@@ -825,10 +897,38 @@ class PersistEventsStore:
                 auth_chain_id, auth_sequence_number = chain_map[auth_id]
 
                 # Step 2a, add link between the event and auth event
+                to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
                 chain_links.add_link(
                     (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                 )
 
+        return to_return
+
+    @classmethod
+    def _persist_chain_cover_index(
+        cls,
+        txn: LoggingTransaction,
+        db_pool: DatabasePool,
+        new_event_links: Dict[str, NewEventChainLinks],
+    ) -> None:
+        db_pool.simple_insert_many_txn(
+            txn,
+            table="event_auth_chains",
+            keys=("event_id", "chain_id", "sequence_number"),
+            values=[
+                (event_id, new_links.chain_id, new_links.sequence_number)
+                for event_id, new_links in new_event_links.items()
+            ],
+        )
+
+        db_pool.simple_delete_many_txn(
+            txn,
+            table="event_auth_chain_to_calculate",
+            keyvalues={},
+            column="event_id",
+            values=new_event_links,
+        )
+
         db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
@@ -838,7 +938,16 @@ class PersistEventsStore:
                 "target_chain_id",
                 "target_sequence_number",
             ),
-            values=list(chain_links.get_additions()),
+            values=[
+                (
+                    new_links.chain_id,
+                    new_links.sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                )
+                for new_links in new_event_links.values()
+                for (target_chain_id, target_sequence_number) in new_links.links
+            ],
         )
 
     @staticmethod