Diffstat (limited to 'synapse/storage/databases')
 synapse/storage/databases/main/event_federation.py | 148 +++++++++++++++++++++++++++++++++++---
 1 file changed, 145 insertions(+), 3 deletions(-)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 18ddb92fcc..332193ad1c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )  # type: LruCache[str, List[Tuple[str, int]]]
 
     async def get_auth_chain(
-        self, event_ids: Collection[str], include_given: bool = False
+        self, room_id: str, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
+            room_id: The room the event is in.
             event_ids: state events
             include_given: include the given events in result
 
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             list of events
         """
         event_ids = await self.get_auth_chain_ids(
-            event_ids, include_given=include_given
+            room_id, event_ids, include_given=include_given
         )
         return await self.get_events_as_list(event_ids)
 
     async def get_auth_chain_ids(
         self,
+        room_id: str,
         event_ids: Collection[str],
         include_given: bool = False,
     ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
+            room_id: The room the event is in.
             event_ids: state events
             include_given: include the given events in result
 
         Returns:
-            An awaitable which resolve to a list of event_ids
+            list of event_ids
         """
+
+        # Check if we have indexed the room so we can use the chain cover
+        # algorithm.
+        room = await self.get_room(room_id)
+        if room["has_auth_chain_index"]:
+            try:
+                return await self.db_pool.runInteraction(
+                    "get_auth_chain_ids_chains",
+                    self._get_auth_chain_ids_using_cover_index_txn,
+                    room_id,
+                    event_ids,
+                    include_given,
+                )
+            except _NoChainCoverIndex:
+                # For whatever reason we don't actually have a chain cover index
+                # for the events in question, so we fall back to the old method.
+                pass
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             include_given,
         )
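The new transaction below drives everything off two tables: `event_auth_chains`, which assigns each state event a chain ID and a sequence number on that chain, and `event_auth_chain_links`, which records which positions on one chain are reachable from positions on another. A toy picture of what they hold (values invented; the real tables are populated elsewhere in Synapse). Note that the lookup below is a single pass over the links, which relies on every reachable chain being linked directly, so the toy data includes the transitive link from chain 3 to chain 1:

    # Toy data (invented) for the two tables the cover index path reads.
    # event_auth_chains: state event -> (chain_id, sequence_number)
    event_auth_chains = {
        "$create": (1, 1),
        "$power_levels": (2, 1),
        "$alice_join": (3, 1),
    }
    # event_auth_chain_links: ((origin_chain, origin_seq), (target_chain,
    # target_seq)) means everything up to target_seq on the target chain is in
    # the auth chain of events at or after origin_seq on the origin chain.
    event_auth_chain_links = [
        ((3, 1), (2, 1)),  # the join is authed by the power levels
        ((2, 1), (1, 1)),  # the power levels are authed by the create event
        ((3, 1), (1, 1)),  # transitive link, stored directly
    ]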
 
+    def _get_auth_chain_ids_using_cover_index_txn(
+        self, txn: LoggingTransaction, room_id: str, event_ids: Collection[str], include_given: bool
+    ) -> List[str]:
+        """Calculates the auth chain IDs using the chain index."""
+
+        # First we look up the chain ID/sequence numbers for the given events.
+
+        initial_events = set(event_ids)
+
+        # The given events that we have found chain ID/sequence numbers for.
+        seen_events = set()  # type: Set[str]
+
+        # A map from chain ID to max sequence number of the given events.
+        event_chains = {}  # type: Dict[int, int]
+
+        sql = """
+            SELECT event_id, chain_id, sequence_number
+            FROM event_auth_chains
+            WHERE %s
+        """
+        for batch in batch_iter(initial_events, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "event_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for event_id, chain_id, sequence_number in txn:
+                seen_events.add(event_id)
+                event_chains[chain_id] = max(
+                    sequence_number, event_chains.get(chain_id, 0)
+                )
+
+        # Check that we actually have a chain ID for all the events.
+        events_missing_chain_info = initial_events.difference(seen_events)
+        if events_missing_chain_info:
+            # This can happen after e.g. a downgrade/upgrade of the server. We
+            # raise an exception and fall back to the previous algorithm.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info,
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Now we look up all links for the chains we have, adding chains that
+        # are reachable from any event.
+        sql = """
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM event_auth_chain_links
+            WHERE %s
+        """
+
+        # A map from chain ID to max sequence number *reachable* from any event ID.
+        chains = {}  # type: Dict[int, int]
+
+        # Add all linked chains reachable from the initial set of chains.
+        for batch in batch_iter(event_chains, 1000):
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch
+            )
+            txn.execute(sql % (clause,), args)
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                # Chains are only reachable if the origin sequence number of
+                # the link is at most the max sequence number in the
+                # origin chain.
+                if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+                    chains[target_chain_id] = max(
+                        target_sequence_number,
+                        chains.get(target_chain_id, 0),
+                    )
+
+        # Add the initial set of chains, excluding the sequence numbers of the
+        # given events themselves.
+        for chain_id, seq_no in event_chains.items():
+            chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+        # Now for each chain we figure out the maximum sequence number reachable
+        # from *any* event ID. Events with a sequence number less than or equal
+        # to that are in the auth chain.
+        if include_given:
+            results = initial_events
+        else:
+            results = set()
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # We can use `execute_values` to efficiently fetch the matching
+            # events when using postgres.
+            sql = """
+                SELECT event_id
+                FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+                WHERE
+                    c.chain_id = l.chain_id
+                    AND sequence_number <= max_seq
+            """
+
+            rows = txn.execute_values(sql, chains.items())
+            results.update(r for r, in rows)
+        else:
+            # For SQLite we just fall back to doing a noddy for loop.
+            sql = """
+                SELECT event_id FROM event_auth_chains
+                WHERE chain_id = ? AND sequence_number <= ?
+            """
+            for chain_id, max_no in chains.items():
+                txn.execute(sql, (chain_id, max_no))
+                results.update(r for r, in txn)
+
+        return list(results)
+
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
     ) -> List[str]:
+        """Calculates the auth chain IDs.
+
+        This is used when we don't have a cover index for the room.
+        """
         if include_given:
             results = set(event_ids)
         else:
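For reference, here is a database-free sketch of the reachability computation that `_get_auth_chain_ids_using_cover_index_txn` performs, using the toy dictionaries from above (the function name and data layout are illustrative, not Synapse APIs; like the SQL above it does a single pass over the links, so it assumes every reachable chain is linked directly):

    from typing import Collection, Dict, List, Set, Tuple

    Link = Tuple[Tuple[int, int], Tuple[int, int]]

    def auth_chain_ids_sketch(
        event_ids: Collection[str],
        event_auth_chains: Dict[str, Tuple[int, int]],
        event_auth_chain_links: List[Link],
        include_given: bool = False,
    ) -> Set[str]:
        # Map from chain ID to the max sequence number of the given events.
        event_chains: Dict[int, int] = {}
        for event_id in event_ids:
            chain_id, seq = event_auth_chains[event_id]
            event_chains[chain_id] = max(seq, event_chains.get(chain_id, 0))

        # Map from chain ID to the max sequence number reachable from any event.
        chains: Dict[int, int] = {}
        for (origin_chain, origin_seq), (target_chain, target_seq) in event_auth_chain_links:
            # A link applies only if its origin point is covered by a given event.
            if origin_seq <= event_chains.get(origin_chain, 0):
                chains[target_chain] = max(target_seq, chains.get(target_chain, 0))

        # Add the initial chains, excluding the given events themselves.
        for chain_id, seq in event_chains.items():
            chains[chain_id] = max(seq - 1, chains.get(chain_id, 0))

        # Everything at or below the reachable sequence number is in the auth chain.
        results = set(event_ids) if include_given else set()
        for event_id, (chain_id, seq) in event_auth_chains.items():
            if seq <= chains.get(chain_id, 0):
                results.add(event_id)
        return results

With the toy data above, `auth_chain_ids_sketch(["$alice_join"], event_auth_chains, event_auth_chain_links)` returns `{"$create", "$power_levels"}`.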