diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ebffd89251..8326640d20 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
+class _NoChainCoverIndex(Exception):
+ def __init__(self, room_id: str):
+ super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
The set of the difference in auth chains.
"""
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_difference_chains",
+ self._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ state_sets,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
+ def _get_auth_chain_difference_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+ ) -> Set[str]:
+ """Calculates the auth chain difference using the chain index.
+
+ See docs/auth_chain_difference_algorithm.md for details
+ """
+
+ # First we look up the chain ID/sequence numbers for all the events, and
+ # work out the chain/sequence numbers reachable from each state set.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Map from event_id -> (chain ID, seq no)
+ chain_info = {} # type: Dict[str, Tuple[int, int]]
+
+ # Map from chain ID -> seq no -> event Id
+ chain_to_event = {} # type: Dict[int, Dict[int, str]]
+
+ # All the chains that we've found that are reachable from the state
+ # sets.
+ seen_chains = set() # type: Set[int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ chain_info[event_id] = (chain_id, sequence_number)
+ seen_chains.add(chain_id)
+ chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(chain_info)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Corresponds to `state_sets`, except as a map from chain ID to max
+ # sequence number reachable from the state set.
+ set_to_chain = [] # type: List[Dict[int, int]]
+ for state_set in state_sets:
+ chains = {} # type: Dict[int, int]
+ set_to_chain.append(chains)
+
+ for event_id in state_set:
+ chain_id, seq_no = chain_info[event_id]
+
+ chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+ # Now we look up all links for the chains we have, adding chains to
+ # set_to_chain that are reachable from each set.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # (We need to take a copy of `seen_chains` as we want to mutate it in
+ # the loop)
+ for batch in batch_iter(set(seen_chains), 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ for chains in set_to_chain:
+ # chains are only reachable if the origin sequence number of
+ # the link is less than the max sequence number in the
+ # origin chain.
+ if origin_sequence_number <= chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number, chains.get(target_chain_id, 0),
+ )
+
+ seen_chains.add(target_chain_id)
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* state set and the minimum sequence number reachable from
+ # *all* state sets. Events in that range are in the auth chain
+ # difference.
+ result = set()
+
+ # Mapping from chain ID to the range of sequence numbers that should be
+ # pulled from the database.
+ chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
+
+ for chain_id in seen_chains:
+ min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+ max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+ if min_seq_no < max_seq_no:
+ # We have a non empty gap, try and fill it from the events that
+ # we have, otherwise add them to the list of gaps to pull out
+ # from the DB.
+ for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+ event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+ if event_id:
+ result.add(event_id)
+ else:
+ chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+ break
+
+ if not chain_to_gap:
+ # If there are no gaps to fetch, we're done!
+ return result
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND min_seq < sequence_number AND sequence_number <= max_seq
+ """
+
+ args = [
+ (chain_id, min_no, max_no)
+ for chain_id, (min_no, max_no) in chain_to_gap.items()
+ ]
+
+ rows = txn.execute_values(sql, args)
+ result.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+ """
+ for chain_id, (min_no, max_no) in chain_to_gap.items():
+ txn.execute(sql, (chain_id, min_no, max_no))
+ result.update(r for r, in txn)
+
+ return result
+
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
+ """Calculates the auth chain difference using a breadth first search.
+
+ This is used when we don't have a cover index for the room.
+ """
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
|