diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 90fb1a1f00..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -35,7 +45,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -366,6 +376,36 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
+ self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+ # _store_rejected_events_txn filters out any events which were
+ # rejected, and returns the filtered list.
+ events_and_contexts = self._store_rejected_events_txn(
+ txn, events_and_contexts=events_and_contexts
+ )
+
+ # From this point onwards the events are only ones that weren't
+ # rejected.
+
+ self._update_metadata_tables_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+ def _persist_event_auth_chain_txn(
+ self, txn: LoggingTransaction, events: List[EventBase],
+ ) -> None:
+
+ # We only care about state events, so this if there are no state events.
+ if not any(e.is_state() for e in events):
+ return
+
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
@@ -381,31 +421,467 @@ class PersistEventsStore:
"room_id": event.room_id,
"auth_id": auth_id,
}
- for event, _ in events_and_contexts
+ for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
- # _store_rejected_events_txn filters out any events which were
- # rejected, and returns the filtered list.
- events_and_contexts = self._store_rejected_events_txn(
- txn, events_and_contexts=events_and_contexts
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+
+ # We ignore legacy rooms that we aren't filling the chain cover index
+ # for.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
)
+ rooms_using_chain_index = {
+ row["room_id"] for row in rows if row["has_auth_chain_index"]
+ }
- # From this point onwards the events are only ones that weren't
- # rejected.
+ state_events = {
+ event.event_id: event
+ for event in events
+ if event.is_state() and event.room_id in rooms_using_chain_index
+ }
- self._update_metadata_tables_txn(
+ if not state_events:
+ return
+
+ # We need to know the type/state_key and auth events of the events we're
+ # calculating chain IDs for. We don't rely on having the full Event
+ # instances as we'll potentially be pulling more events from the DB and
+ # we don't need the overhead of fetching/parsing the full event JSON.
+ event_to_types = {
+ e.event_id: (e.type, e.state_key) for e in state_events.values()
+ }
+ event_to_auth_chain = {
+ e.event_id: e.auth_event_ids() for e in state_events.values()
+ }
+ event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+ self._add_chain_cover_index(
+ txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ )
+
+ @classmethod
+ def _add_chain_cover_index(
+ cls,
+ txn,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ ) -> None:
+ """Calculate the chain cover index for the given events.
+
+ Args:
+ event_to_room_id: Event ID to the room ID of the event
+ event_to_types: Event ID to type and state_key of the event
+ event_to_auth_chain: Event ID to list of auth event IDs of the
+ event (events with no auth events can be excluded).
+ """
+
+ # Map from event ID to chain ID/sequence number.
+ chain_map = {} # type: Dict[str, Tuple[int, int]]
+
+ # Set of event IDs to calculate chain ID/seq numbers for.
+ events_to_calc_chain_id_for = set(event_to_room_id)
+
+ # We check if there are any events that need to be handled in the rooms
+ # we're looking at. These should just be out of band memberships, where
+ # we didn't have the auth chain when we first persisted.
+ rows = db_pool.simple_select_many_txn(
txn,
- events_and_contexts=events_and_contexts,
- all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="room_id",
+ iterable=set(event_to_room_id.values()),
+ retcols=("event_id", "type", "state_key"),
)
+ for row in rows:
+ event_id = row["event_id"]
+ event_type = row["type"]
+ state_key = row["state_key"]
+
+ # (We could pull out the auth events for all rows at once using
+ # simple_select_many, but this case happens rarely and almost always
+ # with a single row.)
+ auth_events = db_pool.simple_select_onecol_txn(
+ txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+ )
- # We call this last as it assumes we've inserted the events into
- # room_memberships, where applicable.
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ events_to_calc_chain_id_for.add(event_id)
+ event_to_types[event_id] = (event_type, state_key)
+ event_to_auth_chain[event_id] = auth_events
+
+ # First we get the chain ID and sequence numbers for the events'
+ # auth events (that aren't also currently being persisted).
+ #
+ # Note that there there is an edge case here where we might not have
+ # calculated chains and sequence numbers for events that were "out
+ # of band". We handle this case by fetching the necessary info and
+ # adding it to the set of events to calculate chain IDs for.
+
+ missing_auth_chains = {
+ a_id
+ for auth_events in event_to_auth_chain.values()
+ for a_id in auth_events
+ if a_id not in events_to_calc_chain_id_for
+ }
+
+ # We loop here in case we find an out of band membership and need to
+ # fetch their auth event info.
+ while missing_auth_chains:
+ sql = """
+ SELECT event_id, events.type, state_key, chain_id, sequence_number
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ WHERE
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", missing_auth_chains,
+ )
+ txn.execute(sql + clause, args)
+
+ missing_auth_chains.clear()
+
+ for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+ event_to_types[auth_id] = (event_type, state_key)
+
+ if chain_id is None:
+ # No chain ID, so the event was persisted out of band.
+ # We add to list of events to calculate auth chains for.
+
+ events_to_calc_chain_id_for.add(auth_id)
+
+ event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
+ txn,
+ "event_auth",
+ keyvalues={"event_id": auth_id},
+ retcol="auth_id",
+ )
+
+ missing_auth_chains.update(
+ e
+ for e in event_to_auth_chain[auth_id]
+ if e not in event_to_types
+ )
+ else:
+ chain_map[auth_id] = (chain_id, sequence_number)
+
+ # Now we check if we have any events where we don't have auth chain,
+ # this should only be out of band memberships.
+ for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+ for auth_id in event_to_auth_chain[event_id]:
+ if (
+ auth_id not in chain_map
+ and auth_id not in events_to_calc_chain_id_for
+ ):
+ events_to_calc_chain_id_for.discard(event_id)
+
+ # If this is an event we're trying to persist we add it to
+ # the list of events to calculate chain IDs for next time
+ # around. (Otherwise we will have already added it to the
+ # table).
+ room_id = event_to_room_id.get(event_id)
+ if room_id:
+ e_type, state_key = event_to_types[event_id]
+ db_pool.simple_insert_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": e_type,
+ "state_key": state_key,
+ },
+ )
+
+ # We stop checking the event's auth events since we've
+ # discarded it.
+ break
+
+ if not events_to_calc_chain_id_for:
+ return
+
+ # Allocate chain ID/sequence numbers to each new event.
+ new_chain_tuples = cls._allocate_chain_ids(
+ txn,
+ db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ events_to_calc_chain_id_for,
+ chain_map,
+ )
+ chain_map.update(new_chain_tuples)
+
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ values=[
+ {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ for event_id, (c_id, seq) in new_chain_tuples.items()
+ ],
+ )
+
+ db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ iterable=new_chain_tuples,
+ )
+
+ # Now we need to calculate any new links between chains caused by
+ # the new events.
+ #
+ # Links are pairs of chain ID/sequence numbers such that for any
+ # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+ # if and only if there is at least one link (CA, S1) -> (CB, S2)
+ # where SA >= S1 and S2 >= SB.
+ #
+ # We try and avoid adding redundant links to the table, e.g. if we
+ # have two links between two chains which both start/end at the
+ # sequence number event (or cross) then one can be safely dropped.
+ #
+ # To calculate new links we look at every new event and:
+ # 1. Fetch the chain ID/sequence numbers of its auth events,
+ # discarding any that are reachable by other auth events, or
+ # that have the same chain ID as the event.
+ # 2. For each retained auth event we:
+ # a. Add a link from the event's to the auth event's chain
+ # ID/sequence number; and
+ # b. Add a link from the event to every chain reachable by the
+ # auth event.
+
+ # Step 1, fetch all existing links from all the chains we've seen
+ # referenced.
+ chain_links = _LinkMap()
+ rows = db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ )
+ for row in rows:
+ chain_links.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ new=False,
+ )
+
+ # We do this in toplogical order to avoid adding redundant links.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ chain_id, sequence_number = chain_map[event_id]
+
+ # Filter out auth events that are reachable by other auth
+ # events. We do this by looking at every permutation of pairs of
+ # auth events (A, B) to check if B is reachable from A.
+ reduction = {
+ a_id
+ for a_id in event_to_auth_chain.get(event_id, [])
+ if chain_map[a_id][0] != chain_id
+ }
+ for start_auth_id, end_auth_id in itertools.permutations(
+ event_to_auth_chain.get(event_id, []), r=2,
+ ):
+ if chain_links.exists_path_from(
+ chain_map[start_auth_id], chain_map[end_auth_id]
+ ):
+ reduction.discard(end_auth_id)
+
+ # Step 2, figure out what the new links are from the reduced
+ # list of auth events.
+ for auth_id in reduction:
+ auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+ # Step 2a, add link between the event and auth event
+ chain_links.add_link(
+ (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+ )
+
+ # Step 2b, add a link to chains reachable from the auth
+ # event.
+ for target_id, target_seq in chain_links.get_links_from(
+ (auth_chain_id, auth_sequence_number)
+ ):
+ if target_id == chain_id:
+ continue
+
+ chain_links.add_link(
+ (chain_id, sequence_number), (target_id, target_seq)
+ )
+
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ values=[
+ {
+ "origin_chain_id": source_id,
+ "origin_sequence_number": source_seq,
+ "target_chain_id": target_id,
+ "target_sequence_number": target_seq,
+ }
+ for (
+ source_id,
+ source_seq,
+ target_id,
+ target_seq,
+ ) in chain_links.get_additions()
+ ],
+ )
+
+ @staticmethod
+ def _allocate_chain_ids(
+ txn,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ events_to_calc_chain_id_for: Set[str],
+ chain_map: Dict[str, Tuple[int, int]],
+ ) -> Dict[str, Tuple[int, int]]:
+ """Allocates, but does not persist, chain ID/sequence numbers for the
+ events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+ for info on args)
+ """
+
+ # We now calculate the chain IDs/sequence numbers for the events. We do
+ # this by looking at the chain ID and sequence number of any auth event
+ # with the same type/state_key and incrementing the sequence number by
+ # one. If there was no match or the chain ID/sequence number is already
+ # taken we generate a new chain.
+ #
+ # We try to reduce the number of times that we hit the database by
+ # batching up calls, to make this more efficient when persisting large
+ # numbers of state events (e.g. during joins).
+ #
+ # We do this by:
+ # 1. Calculating for each event which auth event will be used to
+ # inherit the chain ID, i.e. converting the auth chain graph to a
+ # tree that we can allocate chains on. We also keep track of which
+ # existing chain IDs have been referenced.
+ # 2. Fetching the max allocated sequence number for each referenced
+ # existing chain ID, generating a map from chain ID to the max
+ # allocated sequence number.
+ # 3. Iterating over the tree and allocating a chain ID/seq no. to the
+ # new event, by incrementing the sequence number from the
+ # referenced event's chain ID/seq no. and checking that the
+ # incremented sequence number hasn't already been allocated (by
+ # looking in the map generated in the previous step). We generate a
+ # new chain if the sequence number has already been allocated.
+ #
+
+ existing_chains = set() # type: Set[int]
+ tree = [] # type: List[Tuple[str, Optional[str]]]
+
+ # We need to do this in a topologically sorted order as we want to
+ # generate chain IDs/sequence numbers of an event's auth events before
+ # the event itself.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ for auth_id in event_to_auth_chain.get(event_id, []):
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map.get(auth_id)
+ if existing_chain_id:
+ existing_chains.add(existing_chain_id[0])
+
+ tree.append((event_id, auth_id))
+ break
+ else:
+ tree.append((event_id, None))
+
+ # Fetch the current max sequence number for each existing referenced chain.
+ sql = """
+ SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+ WHERE %s
+ GROUP BY chain_id
+ """
+ clause, args = make_in_list_sql_clause(
+ db_pool.engine, "chain_id", existing_chains
+ )
+ txn.execute(sql % (clause,), args)
+
+ chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
+
+ # Allocate the new events chain ID/sequence numbers.
+ #
+ # To reduce the number of calls to the database we don't allocate a
+ # chain ID number in the loop, instead we use a temporary `object()` for
+ # each new chain ID. Once we've done the loop we generate the necessary
+ # number of new chain IDs in one call, replacing all temporary
+ # objects with real allocated chain IDs.
+
+ unallocated_chain_ids = set() # type: Set[object]
+ new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
+ for event_id, auth_event_id in tree:
+ # If we reference an auth_event_id we fetch the allocated chain ID,
+ # either from the existing `chain_map` or the newly generated
+ # `new_chain_tuples` map.
+ existing_chain_id = None
+ if auth_event_id:
+ existing_chain_id = new_chain_tuples.get(auth_event_id)
+ if not existing_chain_id:
+ existing_chain_id = chain_map[auth_event_id]
+
+ new_chain_tuple = None # type: Optional[Tuple[Any, int]]
+ if existing_chain_id:
+ # We found a chain ID/sequence number candidate, check its
+ # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+
+ if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ # If we need to start a new chain we allocate a temporary chain ID.
+ if not new_chain_tuple:
+ new_chain_tuple = (object(), 1)
+ unallocated_chain_ids.add(new_chain_tuple[0])
+
+ new_chain_tuples[event_id] = new_chain_tuple
+ chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+ # Generate new chain IDs for all unallocated chain IDs.
+ newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ txn, len(unallocated_chain_ids)
+ )
+
+ # Map from potentially temporary chain ID to real chain ID
+ chain_id_to_allocated_map = dict(
+ zip(unallocated_chain_ids, newly_allocated_chain_ids)
+ ) # type: Dict[Any, int]
+ chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+ return {
+ event_id: (chain_id_to_allocated_map[chain_id], seq)
+ for event_id, (chain_id, seq) in new_chain_tuples.items()
+ }
def _persist_transaction_ids_txn(
self,
@@ -489,7 +965,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -508,7 +984,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@@ -520,7 +996,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -540,7 +1016,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@@ -551,7 +1027,7 @@ class PersistEventsStore:
)
if to_insert:
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -799,7 +1275,8 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _store_event_txn(self, txn, events_and_contexts):
- """Insert new events into the event and event_json tables
+ """Insert new events into the event, event_json, redaction and
+ state_events tables.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
@@ -871,6 +1348,29 @@ class PersistEventsStore:
updatevalues={"have_censored": False},
)
+ state_events_and_contexts = [
+ ec for ec in events_and_contexts if ec[0].is_state()
+ ]
+
+ state_values = []
+ for event, context in state_events_and_contexts:
+ vals = {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+
+ # TODO: How does this work with backfilling?
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
+
+ state_values.append(vals)
+
+ self.db_pool.simple_insert_many_txn(
+ txn, table="state_events", values=state_values
+ )
+
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
rejected
@@ -987,29 +1487,6 @@ class PersistEventsStore:
txn, [event for event, _ in events_and_contexts]
)
- state_events_and_contexts = [
- ec for ec in events_and_contexts if ec[0].is_state()
- ]
-
- state_values = []
- for event, context in state_events_and_contexts:
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- state_values.append(vals)
-
- self.db_pool.simple_insert_many_txn(
- txn, table="state_events", values=state_values
- )
-
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1350,7 +1827,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -1379,7 +1856,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@@ -1498,7 +1975,7 @@ class PersistEventsStore:
" )"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1512,7 +1989,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(ev.event_id, ev.room_id)
@@ -1520,3 +1997,131 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
+
+
+@attr.s(slots=True)
+class _LinkMap:
+ """A helper type for tracking links between chains.
+ """
+
+ # Stores the set of links as nested maps: source chain ID -> target chain ID
+ # -> source sequence number -> target sequence number.
+ maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+ # Stores the links that have been added (with new set to true), as tuples of
+ # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+ additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+ def add_link(
+ self,
+ src_tuple: Tuple[int, int],
+ target_tuple: Tuple[int, int],
+ new: bool = True,
+ ) -> bool:
+ """Add a new link between two chains, ensuring no redundant links are added.
+
+ New links should be added in topological order.
+
+ Args:
+ src_tuple: The chain ID/sequence number of the source of the link.
+ target_tuple: The chain ID/sequence number of the target of the link.
+ new: Whether this is a "new" link, i.e. should it be returned
+ by `get_additions`.
+
+ Returns:
+ True if a link was added, false if the given link was dropped as redundant
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+ assert src_chain != target_chain
+
+ if new:
+ # Check if the new link is redundant
+ for current_seq_src, current_seq_target in current_links.items():
+ # If a link "crosses" another link then its redundant. For example
+ # in the following link 1 (L1) is redundant, as any event reachable
+ # via L1 is *also* reachable via L2.
+ #
+ # Chain A Chain B
+ # | |
+ # L1 |------ |
+ # | | |
+ # L2 |---- | -->|
+ # | | |
+ # | |--->|
+ # | |
+ # | |
+ #
+ # So we only need to keep links which *do not* cross, i.e. links
+ # that both start and end above or below an existing link.
+ #
+ # Note, since we add links in topological ordering we should never
+ # see `src_seq` less than `current_seq_src`.
+
+ if current_seq_src <= src_seq and target_seq <= current_seq_target:
+ # This new link is redundant, nothing to do.
+ return False
+
+ self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+ current_links[src_seq] = target_seq
+ return True
+
+ def get_links_from(
+ self, src_tuple: Tuple[int, int]
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Gets the chains reachable from the given chain/sequence number.
+
+ Yields:
+ The chain ID and sequence number the link points to.
+ """
+ src_chain, src_seq = src_tuple
+ for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+ for link_src_seq, target_seq in sequence_numbers.items():
+ if link_src_seq <= src_seq:
+ yield target_id, target_seq
+
+ def get_links_between(
+ self, source_chain: int, target_chain: int
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Gets the links between two chains.
+
+ Yields:
+ The source and target sequence numbers.
+ """
+
+ yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+ def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+ """Gets any newly added links.
+
+ Yields:
+ The source chain ID/sequence number and target chain ID/sequence number
+ """
+
+ for src_chain, src_seq, target_chain, _ in self.additions:
+ target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+ if target_seq is not None:
+ yield (src_chain, src_seq, target_chain, target_seq)
+
+ def exists_path_from(
+ self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+ ) -> bool:
+ """Checks if there is a path between the source chain ID/sequence and
+ target chain ID/sequence.
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ if src_chain == target_chain:
+ return target_seq <= src_seq
+
+ links = self.get_links_between(src_chain, target_chain)
+ for link_start_seq, link_end_seq in links:
+ if link_start_seq <= src_seq and target_seq <= link_end_seq:
+ return True
+
+ return False
|