diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 18ddb92fcc..332193ad1c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
- self, event_ids: Collection[str], include_given: bool = False
+ self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
- event_ids, include_given=include_given
+ room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
+ room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
+ room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
- An awaitable which resolve to a list of event_ids
+ list of event_ids
"""
+
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_ids_chains",
+ self._get_auth_chain_ids_using_cover_index_txn,
+ room_id,
+ event_ids,
+ include_given,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
+ def _get_auth_chain_ids_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
+ ) -> List[str]:
+ """Calculates the auth chain IDs using the chain index."""
+
+ # First we look up the chain ID/sequence numbers for the given events.
+
+ initial_events = set(event_ids)
+
+ # All the events that we've found that are reachable from the events.
+ seen_events = set() # type: Set[str]
+
+ # A map from chain ID to max sequence number of the given events.
+ event_chains = {} # type: Dict[int, int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ seen_events.add(event_id)
+ event_chains[chain_id] = max(
+ sequence_number, event_chains.get(chain_id, 0)
+ )
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(seen_events)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Now we look up all links for the chains we have, adding chains that
+ # are reachable from any event.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # A map from chain ID to max sequence number *reachable* from any event ID.
+ chains = {} # type: Dict[int, int]
+
+ # Add all linked chains reachable from initial set of chains.
+ for batch in batch_iter(event_chains, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ # chains are only reachable if the origin sequence number of
+ # the link is less than the max sequence number in the
+ # origin chain.
+ if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number,
+ chains.get(target_chain_id, 0),
+ )
+
+ # Add the initial set of chains, excluding the sequence corresponding to
+ # initial event.
+ for chain_id, seq_no in event_chains.items():
+ chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* event ID. Events with a sequence less than that are in the
+ # auth chain.
+ if include_given:
+ results = initial_events
+ else:
+ results = set()
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND sequence_number <= max_seq
+ """
+
+ rows = txn.execute_values(sql, chains.items())
+ results.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND sequence_number <= ?
+ """
+ for chain_id, max_no in chains.items():
+ txn.execute(sql, (chain_id, max_no))
+ results.update(r for r, in txn)
+
+ return list(results)
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
+ """Calculates the auth chain IDs.
+
+ This is used when we don't have a cover index for the room.
+ """
if include_given:
results = set(event_ids)
else:
|