diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc
new file mode 100644
index 0000000000..1a11e30944
--- /dev/null
+++ b/changelog.d/8868.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions for new rooms.
diff --git a/docs/auth_chain_diff.dot b/docs/auth_chain_diff.dot
new file mode 100644
index 0000000000..978d579ada
--- /dev/null
+++ b/docs/auth_chain_diff.dot
@@ -0,0 +1,32 @@
+digraph auth {
+ nodesep=0.5;
+ rankdir="RL";
+
+ C [label="Create (1,1)"];
+
+ BJ [label="Bob's Join (2,1)", color=red];
+ BJ2 [label="Bob's Join (2,2)", color=red];
+ BJ2 -> BJ [color=red, dir=none];
+
+ subgraph cluster_foo {
+ A1 [label="Alice's invite (4,1)", color=blue];
+ A2 [label="Alice's Join (4,2)", color=blue];
+ A3 [label="Alice's Join (4,3)", color=blue];
+ A3 -> A2 -> A1 [color=blue, dir=none];
+ color=none;
+ }
+
+ PL1 [label="Power Level (3,1)", color=darkgreen];
+ PL2 [label="Power Level (3,2)", color=darkgreen];
+ PL2 -> PL1 [color=darkgreen, dir=none];
+
+ {rank = same; C; BJ; PL1; A1;}
+
+ A1 -> C [color=grey];
+ A1 -> BJ [color=grey];
+ PL1 -> C [color=grey];
+ BJ2 -> PL1 [penwidth=2];
+
+ A3 -> PL2 [penwidth=2];
+ A1 -> PL1 -> BJ -> C [penwidth=2];
+}
diff --git a/docs/auth_chain_diff.dot.png b/docs/auth_chain_diff.dot.png
new file mode 100644
index 0000000000..771c07308f
--- /dev/null
+++ b/docs/auth_chain_diff.dot.png
Binary files differdiff --git a/docs/auth_chain_difference_algorithm.md b/docs/auth_chain_difference_algorithm.md
new file mode 100644
index 0000000000..30f72a70da
--- /dev/null
+++ b/docs/auth_chain_difference_algorithm.md
@@ -0,0 +1,108 @@
+# Auth Chain Difference Algorithm
+
+The auth chain difference algorithm is used by V2 state resolution, where a
+naive implementation can be a significant source of CPU and DB usage.
+
+### Definitions
+
+A *state set* is a set of state events; e.g. the input of a state resolution
+algorithm is a collection of state sets.
+
+The *auth chain* of a set of events are all the events' auth events and *their*
+auth events, recursively (i.e. the events reachable by walking the graph induced
+by an event's auth events links).
+
+The *auth chain difference* of a collection of state sets is the union minus the
+intersection of the sets of auth chains corresponding to the state sets, i.e an
+event is in the auth chain difference if it is reachable by walking the auth
+event graph from at least one of the state sets but not from *all* of the state
+sets.
+
+## Breadth First Walk Algorithm
+
+A way of calculating the auth chain difference without calculating the full auth
+chains for each state set is to do a parallel breadth first walk (ordered by
+depth) of each state set's auth chain. By tracking which events are reachable
+from each state set we can finish early if every pending event is reachable from
+every state set.
+
+This can work well for state sets that have a small auth chain difference, but
+can be very inefficient for larger differences. However, this algorithm is still
+used if we don't have a chain cover index for the room (e.g. because we're in
+the process of indexing it).
+
+## Chain Cover Index
+
+Synapse computes auth chain differences by pre-computing a "chain cover" index
+for the auth chain in a room, allowing efficient reachability queries like "is
+event A in the auth chain of event B". This is done by assigning every event a
+*chain ID* and *sequence number* (e.g. `(5,3)`), and having a map of *links*
+between chains (e.g. `(5,3) -> (2,4)`) such that A is reachable by B (i.e. `A`
+is in the auth chain of `B`) if and only if either:
+
+1. A and B have the same chain ID and `A`'s sequence number is less than `B`'s
+ sequence number; or
+2. there is a link `L` between `B`'s chain ID and `A`'s chain ID such that
+ `L.start_seq_no` <= `B.seq_no` and `A.seq_no` <= `L.end_seq_no`.
+
+There are actually two potential implementations, one where we store links from
+each chain to every other reachable chain (the transitive closure of the links
+graph), and one where we remove redundant links (the transitive reduction of the
+links graph) e.g. if we have chains `C3 -> C2 -> C1` then the link `C3 -> C1`
+would not be stored. Synapse uses the former implementations so that it doesn't
+need to recurse to test reachability between chains.
+
+### Example
+
+An example auth graph would look like the following, where chains have been
+formed based on type/state_key and are denoted by colour and are labelled with
+`(chain ID, sequence number)`. Links are denoted by the arrows (links in grey
+are those that would be remove in the second implementation described above).
+
+
+
+Note that we don't include all links between events and their auth events, as
+most of those links would be redundant. For example, all events point to the
+create event, but each chain only needs the one link from it's base to the
+create event.
+
+## Using the Index
+
+This index can be used to calculate the auth chain difference of the state sets
+by looking at the chain ID and sequence numbers reachable from each state set:
+
+1. For every state set lookup the chain ID/sequence numbers of each state event
+2. Use the index to find all chains and the maximum sequence number reachable
+ from each state set.
+3. The auth chain difference is then all events in each chain that have sequence
+ numbers between the maximum sequence number reachable from *any* state set and
+ the minimum reachable by *all* state sets (if any).
+
+Note that steps 2 is effectively calculating the auth chain for each state set
+(in terms of chain IDs and sequence numbers), and step 3 is calculating the
+difference between the union and intersection of the auth chains.
+
+### Worked Example
+
+For example, given the above graph, we can calculate the difference between
+state sets consisting of:
+
+1. `S1`: Alice's invite `(4,1)` and Bob's second join `(2,2)`; and
+2. `S2`: Alice's second join `(4,3)` and Bob's first join `(2,1)`.
+
+Using the index we see that the following auth chains are reachable from each
+state set:
+
+1. `S1`: `(1,1)`, `(2,2)`, `(3,1)` & `(4,1)`
+2. `S2`: `(1,1)`, `(2,1)`, `(3,2)` & `(4,3)`
+
+And so, for each the ranges that are in the auth chain difference:
+1. Chain 1: None, (since everything can reach the create event).
+2. Chain 2: The range `(1, 2]` (i.e. just `2`), as `1` is reachable by all state
+ sets and the maximum reachable is `2` (corresponding to Bob's second join).
+3. Chain 3: Similarly the range `(1, 2]` (corresponding to the second power
+ level).
+4. Chain 4: The range `(1, 3]` (corresponding to both of Alice's joins).
+
+So the final result is: Bob's second join `(2,2)`, the second power level
+`(3,2)` and both of Alice's joins `(4,2)` & `(4,3)`.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b70ca3087b..6cfadc2b4e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -179,6 +179,9 @@ class LoggingDatabaseConnection:
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+R = TypeVar("R")
+
+
class LoggingTransaction:
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -266,6 +269,20 @@ class LoggingTransaction:
for val in args:
self.execute(sql, val)
+ def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
+ """Corresponds to psycopg2.extras.execute_values. Only available when
+ using postgres.
+
+ Always sets fetch=True when caling `execute_values`, so will return the
+ results.
+ """
+ assert isinstance(self.database_engine, PostgresEngine)
+ from psycopg2.extras import execute_values # type: ignore
+
+ return self._do_execute(
+ lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args
+ )
+
def execute(self, sql: str, *args: Any) -> None:
self._do_execute(self.txn.execute, sql, *args)
@@ -276,7 +293,7 @@ class LoggingTransaction:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(line.strip() for line in sql.splitlines() if line.strip())
- def _do_execute(self, func, sql: str, *args: Any) -> None:
+ def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -347,9 +364,6 @@ class PerformanceCounters:
return top_n_counters
-R = TypeVar("R")
-
-
class DatabasePool:
"""Wraps a single physical database and connection pool.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ebffd89251..8326640d20 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.types import Cursor
from synapse.types import Collection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -32,6 +34,11 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
+class _NoChainCoverIndex(Exception):
+ def __init__(self, room_id: str):
+ super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
+
+
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -151,15 +158,193 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
The set of the difference in auth chains.
"""
+ # Check if we have indexed the room so we can use the chain cover
+ # algorithm.
+ room = await self.get_room(room_id)
+ if room["has_auth_chain_index"]:
+ try:
+ return await self.db_pool.runInteraction(
+ "get_auth_chain_difference_chains",
+ self._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ state_sets,
+ )
+ except _NoChainCoverIndex:
+ # For whatever reason we don't actually have a chain cover index
+ # for the events in question, so we fall back to the old method.
+ pass
+
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
+ def _get_auth_chain_difference_using_cover_index_txn(
+ self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
+ ) -> Set[str]:
+ """Calculates the auth chain difference using the chain index.
+
+ See docs/auth_chain_difference_algorithm.md for details
+ """
+
+ # First we look up the chain ID/sequence numbers for all the events, and
+ # work out the chain/sequence numbers reachable from each state set.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Map from event_id -> (chain ID, seq no)
+ chain_info = {} # type: Dict[str, Tuple[int, int]]
+
+ # Map from chain ID -> seq no -> event Id
+ chain_to_event = {} # type: Dict[int, Dict[int, str]]
+
+ # All the chains that we've found that are reachable from the state
+ # sets.
+ seen_chains = set() # type: Set[int]
+
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for event_id, chain_id, sequence_number in txn:
+ chain_info[event_id] = (chain_id, sequence_number)
+ seen_chains.add(chain_id)
+ chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+ # Check that we actually have a chain ID for all the events.
+ events_missing_chain_info = initial_events.difference(chain_info)
+ if events_missing_chain_info:
+ # This can happen due to e.g. downgrade/upgrade of the server. We
+ # raise an exception and fall back to the previous algorithm.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info,
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Corresponds to `state_sets`, except as a map from chain ID to max
+ # sequence number reachable from the state set.
+ set_to_chain = [] # type: List[Dict[int, int]]
+ for state_set in state_sets:
+ chains = {} # type: Dict[int, int]
+ set_to_chain.append(chains)
+
+ for event_id in state_set:
+ chain_id, seq_no = chain_info[event_id]
+
+ chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
+
+ # Now we look up all links for the chains we have, adding chains to
+ # set_to_chain that are reachable from each set.
+ sql = """
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM event_auth_chain_links
+ WHERE %s
+ """
+
+ # (We need to take a copy of `seen_chains` as we want to mutate it in
+ # the loop)
+ for batch in batch_iter(set(seen_chains), 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch
+ )
+ txn.execute(sql % (clause,), args)
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ for chains in set_to_chain:
+ # chains are only reachable if the origin sequence number of
+ # the link is less than the max sequence number in the
+ # origin chain.
+ if origin_sequence_number <= chains.get(origin_chain_id, 0):
+ chains[target_chain_id] = max(
+ target_sequence_number, chains.get(target_chain_id, 0),
+ )
+
+ seen_chains.add(target_chain_id)
+
+ # Now for each chain we figure out the maximum sequence number reachable
+ # from *any* state set and the minimum sequence number reachable from
+ # *all* state sets. Events in that range are in the auth chain
+ # difference.
+ result = set()
+
+ # Mapping from chain ID to the range of sequence numbers that should be
+ # pulled from the database.
+ chain_to_gap = {} # type: Dict[int, Tuple[int, int]]
+
+ for chain_id in seen_chains:
+ min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
+ max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+
+ if min_seq_no < max_seq_no:
+ # We have a non empty gap, try and fill it from the events that
+ # we have, otherwise add them to the list of gaps to pull out
+ # from the DB.
+ for seq_no in range(min_seq_no + 1, max_seq_no + 1):
+ event_id = chain_to_event.get(chain_id, {}).get(seq_no)
+ if event_id:
+ result.add(event_id)
+ else:
+ chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
+ break
+
+ if not chain_to_gap:
+ # If there are no gaps to fetch, we're done!
+ return result
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # We can use `execute_values` to efficiently fetch the gaps when
+ # using postgres.
+ sql = """
+ SELECT event_id
+ FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq)
+ WHERE
+ c.chain_id = l.chain_id
+ AND min_seq < sequence_number AND sequence_number <= max_seq
+ """
+
+ args = [
+ (chain_id, min_no, max_no)
+ for chain_id, (min_no, max_no) in chain_to_gap.items()
+ ]
+
+ rows = txn.execute_values(sql, args)
+ result.update(r for r, in rows)
+ else:
+ # For SQLite we just fall back to doing a noddy for loop.
+ sql = """
+ SELECT event_id FROM event_auth_chains
+ WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
+ """
+ for chain_id, (min_no, max_no) in chain_to_gap.items():
+ txn.execute(sql, (chain_id, min_no, max_no))
+ result.update(r for r, in txn)
+
+ return result
+
def _get_auth_chain_difference_txn(
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
+ """Calculates the auth chain difference using a breadth first search.
+
+ This is used when we don't have a cover index for the room.
+ """
# Algorithm Description
# ~~~~~~~~~~~~~~~~~~~~~
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5e7753e09b..186f064036 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,17 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
from prometheus_client import Counter
@@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
-from synapse.util.iterutils import batch_iter
+from synapse.util.iterutils import batch_iter, sorted_topologically
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -89,6 +100,14 @@ class PersistEventsStore:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
+ def get_chain_id_txn(txn):
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+ return txn.fetchone()[0]
+
+ self._event_chain_id_gen = build_sequence_generator(
+ db.engine, get_chain_id_txn, "event_auth_chain_id"
+ )
+
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@@ -366,6 +385,36 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
+ self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+
+ # _store_rejected_events_txn filters out any events which were
+ # rejected, and returns the filtered list.
+ events_and_contexts = self._store_rejected_events_txn(
+ txn, events_and_contexts=events_and_contexts
+ )
+
+ # From this point onwards the events are only ones that weren't
+ # rejected.
+
+ self._update_metadata_tables_txn(
+ txn,
+ events_and_contexts=events_and_contexts,
+ all_events_and_contexts=all_events_and_contexts,
+ backfilled=backfilled,
+ )
+
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+ def _persist_event_auth_chain_txn(
+ self, txn: LoggingTransaction, events: List[EventBase],
+ ) -> None:
+
+ # We only care about state events, so this if there are no state events.
+ if not any(e.is_state() for e in events):
+ return
+
# We want to store event_auth mappings for rejected events, as they're
# used in state res v2.
# This is only necessary if the rejected event appears in an accepted
@@ -381,31 +430,357 @@ class PersistEventsStore:
"room_id": event.room_id,
"auth_id": auth_id,
}
- for event, _ in events_and_contexts
+ for event in events
for auth_id in event.auth_event_ids()
if event.is_state()
],
)
- # _store_rejected_events_txn filters out any events which were
- # rejected, and returns the filtered list.
- events_and_contexts = self._store_rejected_events_txn(
- txn, events_and_contexts=events_and_contexts
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+
+ # We ignore legacy rooms that we aren't filling the chain cover index
+ # for.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="rooms",
+ column="room_id",
+ iterable={event.room_id for event in events if event.is_state()},
+ keyvalues={},
+ retcols=("room_id", "has_auth_chain_index"),
)
+ rooms_using_chain_index = {
+ row["room_id"] for row in rows if row["has_auth_chain_index"]
+ }
- # From this point onwards the events are only ones that weren't
- # rejected.
+ state_events = {
+ event.event_id: event
+ for event in events
+ if event.is_state() and event.room_id in rooms_using_chain_index
+ }
- self._update_metadata_tables_txn(
+ if not state_events:
+ return
+
+ # Map from event ID to chain ID/sequence number.
+ chain_map = {} # type: Dict[str, Tuple[int, int]]
+
+ # We need to know the type/state_key and auth events of the events we're
+ # calculating chain IDs for. We don't rely on having the full Event
+ # instances as we'll potentially be pulling more events from the DB and
+ # we don't need the overhead of fetching/parsing the full event JSON.
+ event_to_types = {
+ e.event_id: (e.type, e.state_key) for e in state_events.values()
+ }
+ event_to_auth_chain = {
+ e.event_id: e.auth_event_ids() for e in state_events.values()
+ }
+
+ # Set of event IDs to calculate chain ID/seq numbers for.
+ events_to_calc_chain_id_for = set(state_events)
+
+ # We check if there are any events that need to be handled in the rooms
+ # we're looking at. These should just be out of band memberships, where
+ # we didn't have the auth chain when we first persisted.
+ rows = self.db_pool.simple_select_many_txn(
txn,
- events_and_contexts=events_and_contexts,
- all_events_and_contexts=all_events_and_contexts,
- backfilled=backfilled,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="room_id",
+ iterable={e.room_id for e in state_events.values()},
+ retcols=("event_id", "type", "state_key"),
)
+ for row in rows:
+ event_id = row["event_id"]
+ event_type = row["type"]
+ state_key = row["state_key"]
+
+ # (We could pull out the auth events for all rows at once using
+ # simple_select_many, but this case happens rarely and almost always
+ # with a single row.)
+ auth_events = self.db_pool.simple_select_onecol_txn(
+ txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
+ )
- # We call this last as it assumes we've inserted the events into
- # room_memberships, where applicable.
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+ events_to_calc_chain_id_for.add(event_id)
+ event_to_types[event_id] = (event_type, state_key)
+ event_to_auth_chain[event_id] = auth_events
+
+ # First we get the chain ID and sequence numbers for the events'
+ # auth events (that aren't also currently being persisted).
+ #
+ # Note that there there is an edge case here where we might not have
+ # calculated chains and sequence numbers for events that were "out
+ # of band". We handle this case by fetching the necessary info and
+ # adding it to the set of events to calculate chain IDs for.
+
+ missing_auth_chains = {
+ a_id
+ for auth_events in event_to_auth_chain.values()
+ for a_id in auth_events
+ if a_id not in events_to_calc_chain_id_for
+ }
+
+ # We loop here in case we find an out of band membership and need to
+ # fetch their auth event info.
+ while missing_auth_chains:
+ sql = """
+ SELECT event_id, events.type, state_key, chain_id, sequence_number
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ WHERE
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", missing_auth_chains,
+ )
+ txn.execute(sql + clause, args)
+
+ missing_auth_chains.clear()
+
+ for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+ event_to_types[auth_id] = (event_type, state_key)
+
+ if chain_id is None:
+ # No chain ID, so the event was persisted out of band.
+ # We add to list of events to calculate auth chains for.
+
+ events_to_calc_chain_id_for.add(auth_id)
+
+ event_to_auth_chain[
+ auth_id
+ ] = self.db_pool.simple_select_onecol_txn(
+ txn,
+ "event_auth",
+ keyvalues={"event_id": auth_id},
+ retcol="auth_id",
+ )
+
+ missing_auth_chains.update(
+ e
+ for e in event_to_auth_chain[auth_id]
+ if e not in event_to_types
+ )
+ else:
+ chain_map[auth_id] = (chain_id, sequence_number)
+
+ # Now we check if we have any events where we don't have auth chain,
+ # this should only be out of band memberships.
+ for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
+ for auth_id in event_to_auth_chain[event_id]:
+ if (
+ auth_id not in chain_map
+ and auth_id not in events_to_calc_chain_id_for
+ ):
+ events_to_calc_chain_id_for.discard(event_id)
+
+ # If this is an event we're trying to persist we add it to
+ # the list of events to calculate chain IDs for next time
+ # around. (Otherwise we will have already added it to the
+ # table).
+ event = state_events.get(event_id)
+ if event:
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ )
+
+ # We stop checking the event's auth events since we've
+ # discarded it.
+ break
+
+ if not events_to_calc_chain_id_for:
+ return
+
+ # We now calculate the chain IDs/sequence numbers for the events. We
+ # do this by looking at the chain ID and sequence number of any auth
+ # event with the same type/state_key and incrementing the sequence
+ # number by one. If there was no match or the chain ID/sequence
+ # number is already taken we generate a new chain.
+ #
+ # We need to do this in a topologically sorted order as we want to
+ # generate chain IDs/sequence numbers of an event's auth events
+ # before the event itself.
+ chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
+ new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ existing_chain_id = None
+ for auth_id in event_to_auth_chain[event_id]:
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map[auth_id]
+ break
+
+ new_chain_tuple = None
+ if existing_chain_id:
+ # We found a chain ID/sequence number candidate, check its
+ # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+ if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
+ already_allocated = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_auth_chains",
+ keyvalues={
+ "chain_id": proposed_new_id,
+ "sequence_number": proposed_new_seq,
+ },
+ retcol="event_id",
+ allow_none=True,
+ )
+ if already_allocated:
+ # Mark it as already allocated so we don't need to hit
+ # the DB again.
+ chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
+ else:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ if not new_chain_tuple:
+ new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+
+ chains_tuples_allocated.add(new_chain_tuple)
+
+ chain_map[event_id] = new_chain_tuple
+ new_chain_tuples[event_id] = new_chain_tuple
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ values=[
+ {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
+ for event_id, (c_id, seq) in new_chain_tuples.items()
+ ],
+ )
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ iterable=new_chain_tuples,
+ )
+
+ # Now we need to calculate any new links between chains caused by
+ # the new events.
+ #
+ # Links are pairs of chain ID/sequence numbers such that for any
+ # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
+ # if and only if there is at least one link (CA, S1) -> (CB, S2)
+ # where SA >= S1 and S2 >= SB.
+ #
+ # We try and avoid adding redundant links to the table, e.g. if we
+ # have two links between two chains which both start/end at the
+ # sequence number event (or cross) then one can be safely dropped.
+ #
+ # To calculate new links we look at every new event and:
+ # 1. Fetch the chain ID/sequence numbers of its auth events,
+ # discarding any that are reachable by other auth events, or
+ # that have the same chain ID as the event.
+ # 2. For each retained auth event we:
+ # a. Add a link from the event's to the auth event's chain
+ # ID/sequence number; and
+ # b. Add a link from the event to every chain reachable by the
+ # auth event.
+
+ # Step 1, fetch all existing links from all the chains we've seen
+ # referenced.
+ chain_links = _LinkMap()
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable={chain_id for chain_id, _ in chain_map.values()},
+ keyvalues={},
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ )
+ for row in rows:
+ chain_links.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ new=False,
+ )
+
+ # We do this in toplogical order to avoid adding redundant links.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ chain_id, sequence_number = chain_map[event_id]
+
+ # Filter out auth events that are reachable by other auth
+ # events. We do this by looking at every permutation of pairs of
+ # auth events (A, B) to check if B is reachable from A.
+ reduction = {
+ a_id
+ for a_id in event_to_auth_chain[event_id]
+ if chain_map[a_id][0] != chain_id
+ }
+ for start_auth_id, end_auth_id in itertools.permutations(
+ event_to_auth_chain[event_id], r=2,
+ ):
+ if chain_links.exists_path_from(
+ chain_map[start_auth_id], chain_map[end_auth_id]
+ ):
+ reduction.discard(end_auth_id)
+
+ # Step 2, figure out what the new links are from the reduced
+ # list of auth events.
+ for auth_id in reduction:
+ auth_chain_id, auth_sequence_number = chain_map[auth_id]
+
+ # Step 2a, add link between the event and auth event
+ chain_links.add_link(
+ (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
+ )
+
+ # Step 2b, add a link to chains reachable from the auth
+ # event.
+ for target_id, target_seq in chain_links.get_links_from(
+ (auth_chain_id, auth_sequence_number)
+ ):
+ if target_id == chain_id:
+ continue
+
+ chain_links.add_link(
+ (chain_id, sequence_number), (target_id, target_seq)
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chain_links",
+ values=[
+ {
+ "origin_chain_id": source_id,
+ "origin_sequence_number": source_seq,
+ "target_chain_id": target_id,
+ "target_sequence_number": target_seq,
+ }
+ for (
+ source_id,
+ source_seq,
+ target_id,
+ target_seq,
+ ) in chain_links.get_additions()
+ ],
+ )
def _persist_transaction_ids_txn(
self,
@@ -1521,3 +1896,131 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
+
+
+@attr.s(slots=True)
+class _LinkMap:
+ """A helper type for tracking links between chains.
+ """
+
+ # Stores the set of links as nested maps: source chain ID -> target chain ID
+ # -> source sequence number -> target sequence number.
+ maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+
+ # Stores the links that have been added (with new set to true), as tuples of
+ # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
+ additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+
+ def add_link(
+ self,
+ src_tuple: Tuple[int, int],
+ target_tuple: Tuple[int, int],
+ new: bool = True,
+ ) -> bool:
+ """Add a new link between two chains, ensuring no redundant links are added.
+
+ New links should be added in topological order.
+
+ Args:
+ src_tuple: The chain ID/sequence number of the source of the link.
+ target_tuple: The chain ID/sequence number of the target of the link.
+ new: Whether this is a "new" link, i.e. should it be returned
+ by `get_additions`.
+
+ Returns:
+ True if a link was added, false if the given link was dropped as redundant
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
+
+ assert src_chain != target_chain
+
+ if new:
+ # Check if the new link is redundant
+ for current_seq_src, current_seq_target in current_links.items():
+ # If a link "crosses" another link then its redundant. For example
+ # in the following link 1 (L1) is redundant, as any event reachable
+ # via L1 is *also* reachable via L2.
+ #
+ # Chain A Chain B
+ # | |
+ # L1 |------ |
+ # | | |
+ # L2 |---- | -->|
+ # | | |
+ # | |--->|
+ # | |
+ # | |
+ #
+ # So we only need to keep links which *do not* cross, i.e. links
+ # that both start and end above or below an existing link.
+ #
+ # Note, since we add links in topological ordering we should never
+ # see `src_seq` less than `current_seq_src`.
+
+ if current_seq_src <= src_seq and target_seq <= current_seq_target:
+ # This new link is redundant, nothing to do.
+ return False
+
+ self.additions.add((src_chain, src_seq, target_chain, target_seq))
+
+ current_links[src_seq] = target_seq
+ return True
+
+ def get_links_from(
+ self, src_tuple: Tuple[int, int]
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Gets the chains reachable from the given chain/sequence number.
+
+ Yields:
+ The chain ID and sequence number the link points to.
+ """
+ src_chain, src_seq = src_tuple
+ for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
+ for link_src_seq, target_seq in sequence_numbers.items():
+ if link_src_seq <= src_seq:
+ yield target_id, target_seq
+
+ def get_links_between(
+ self, source_chain: int, target_chain: int
+ ) -> Generator[Tuple[int, int], None, None]:
+ """Gets the links between two chains.
+
+ Yields:
+ The source and target sequence numbers.
+ """
+
+ yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
+
+ def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
+ """Gets any newly added links.
+
+ Yields:
+ The source chain ID/sequence number and target chain ID/sequence number
+ """
+
+ for src_chain, src_seq, target_chain, _ in self.additions:
+ target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
+ if target_seq is not None:
+ yield (src_chain, src_seq, target_chain, target_seq)
+
+ def exists_path_from(
+ self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
+ ) -> bool:
+ """Checks if there is a path between the source chain ID/sequence and
+ target chain ID/sequence.
+ """
+ src_chain, src_seq = src_tuple
+ target_chain, target_seq = target_tuple
+
+ if src_chain == target_chain:
+ return target_seq <= src_seq
+
+ links = self.get_links_between(src_chain, target_chain)
+ for link_start_seq, link_end_seq in links:
+ if link_start_seq <= src_seq and target_seq <= link_end_seq:
+ return True
+
+ return False
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4650d0689b..284f2ce77c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -84,7 +84,7 @@ class RoomWorkerStore(SQLBaseStore):
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
+ retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
desc="get_room",
allow_none=True,
)
@@ -1166,6 +1166,37 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# It's overridden by RoomStore for the synapse master.
raise NotImplementedError()
+ async def has_auth_chain_index(self, room_id: str) -> bool:
+ """Check if the room has (or can have) a chain cover index.
+
+ Defaults to True if we don't have an entry in `rooms` table nor any
+ events for the room.
+ """
+
+ has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="has_auth_chain_index",
+ desc="has_auth_chain_index",
+ allow_none=True,
+ )
+
+ if has_auth_chain_index:
+ return True
+
+ # It's possible that we already have events for the room in our DB
+ # without a corresponding room entry. If we do then we don't want to
+ # mark the room as having an auth chain cover index.
+ max_ordering = await self.db_pool.simple_select_one_onecol(
+ table="events",
+ keyvalues={"room_id": room_id},
+ retcol="MAX(stream_ordering)",
+ allow_none=True,
+ desc="upsert_room_on_join",
+ )
+
+ return max_ordering is None
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -1179,12 +1210,21 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version
currently in the table.
"""
+ # It's possible that we already have events for the room in our DB
+ # without a corresponding room entry. If we do then we don't want to
+ # mark the room as having an auth chain cover index.
+ has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
keyvalues={"room_id": room_id},
values={"room_version": room_version.identifier},
- insertion_values={"is_public": False, "creator": ""},
+ insertion_values={
+ "is_public": False,
+ "creator": "",
+ "has_auth_chain_index": has_auth_chain_index,
+ },
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
lock=False,
@@ -1219,6 +1259,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"creator": room_creator_user_id,
"is_public": is_public,
"room_version": room_version.identifier,
+ "has_auth_chain_index": True,
},
)
if is_public:
@@ -1247,6 +1288,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
"""
+ # It's possible that we already have events for the room in our DB
+ # without a corresponding room entry. If we do then we don't want to
+ # mark the room as having an auth chain cover index.
+ has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
await self.db_pool.simple_upsert(
desc="maybe_store_room_on_outlier_membership",
table="rooms",
@@ -1256,6 +1302,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
"room_version": room_version.identifier,
"is_public": False,
"creator": "",
+ "has_auth_chain_index": has_auth_chain_index,
},
# rooms has a unique constraint on room_id, so no need to lock when doing an
# emulated upsert.
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
new file mode 100644
index 0000000000..729196cfd5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql
@@ -0,0 +1,52 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- See docs/auth_chain_difference_algorithm.md
+
+CREATE TABLE event_auth_chains (
+ event_id TEXT PRIMARY KEY,
+ chain_id BIGINT NOT NULL,
+ sequence_number BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX event_auth_chains_c_seq_index ON event_auth_chains (chain_id, sequence_number);
+
+
+CREATE TABLE event_auth_chain_links (
+ origin_chain_id BIGINT NOT NULL,
+ origin_sequence_number BIGINT NOT NULL,
+
+ target_chain_id BIGINT NOT NULL,
+ target_sequence_number BIGINT NOT NULL
+);
+
+
+CREATE INDEX event_auth_chain_links_idx ON event_auth_chain_links (origin_chain_id, target_chain_id);
+
+
+-- Events that we have persisted but not calculated auth chains for,
+-- e.g. out of band memberships (where we don't have the auth chain)
+CREATE TABLE event_auth_chain_to_calculate (
+ event_id TEXT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL
+);
+
+CREATE INDEX event_auth_chain_to_calculate_rm_id ON event_auth_chain_to_calculate(room_id);
+
+
+-- Whether we've calculated the above index for a room.
+ALTER TABLE rooms ADD COLUMN has_auth_chain_index BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
new file mode 100644
index 0000000000..e8a035bbeb
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/04_event_auth_chains.sql.postgres
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS event_auth_chain_id;
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 06faeebe7f..f7b4857a84 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -13,8 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import heapq
from itertools import islice
-from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+from typing import (
+ Dict,
+ Generator,
+ Iterable,
+ Iterator,
+ Mapping,
+ Sequence,
+ Set,
+ Tuple,
+ TypeVar,
+)
+
+from synapse.types import Collection
T = TypeVar("T")
@@ -46,3 +59,41 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
If the input is empty, no chunks are returned.
"""
return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
+
+
+def sorted_topologically(
+ nodes: Iterable[T], graph: Mapping[T, Collection[T]],
+) -> Generator[T, None, None]:
+ """Given a set of nodes and a graph, yield the nodes in toplogical order.
+
+ For example `sorted_topologically([1, 2], {1: [2]})` will yield `2, 1`.
+ """
+
+ # This is implemented by Kahn's algorithm.
+
+ degree_map = {node: 0 for node in nodes}
+ reverse_graph = {} # type: Dict[T, Set[T]]
+
+ for node, edges in graph.items():
+ if node not in degree_map:
+ continue
+
+ for edge in edges:
+ if edge in degree_map:
+ degree_map[node] += 1
+
+ reverse_graph.setdefault(edge, set()).add(node)
+ reverse_graph.setdefault(node, set())
+
+ zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+ heapq.heapify(zero_degree)
+
+ while zero_degree:
+ node = heapq.heappop(zero_degree)
+ yield node
+
+ for edge in reverse_graph[node]:
+ if edge in degree_map:
+ degree_map[edge] -= 1
+ if degree_map[edge] == 0:
+ heapq.heappush(zero_degree, edge)
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
new file mode 100644
index 0000000000..83c377824b
--- /dev/null
+++ b/tests/storage/test_event_chain.py
@@ -0,0 +1,472 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+from twisted.trial import unittest
+
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.storage.databases.main.events import _LinkMap
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventChainStoreTestCase(HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self._next_stream_ordering = 1
+
+ def test_simple(self):
+ """Test that the example in `docs/auth_chain_difference_algorithm.md`
+ works.
+ """
+
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ power_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ bob_join_2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join_2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[
+ create.event_id,
+ alice_join.event_id,
+ power_2.event_id,
+ ],
+ )
+ )
+
+ events = [
+ create,
+ bob_join,
+ power,
+ alice_invite,
+ alice_join,
+ bob_join_2,
+ power_2,
+ alice_join2,
+ ]
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ (bob_join_2, power),
+ (alice_join2, power_2),
+ ]
+
+ self.persist(events)
+ chain_map, link_map = self.fetch_chains(events)
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ # Test that everything can reach the create event, but the create event
+ # can't reach anything.
+ for event in events[1:]:
+ self.assertTrue(
+ link_map.exists_path_from(
+ chain_map[event.event_id], chain_map[create.event_id]
+ ),
+ )
+
+ self.assertFalse(
+ link_map.exists_path_from(
+ chain_map[create.event_id], chain_map[event.event_id],
+ ),
+ )
+
+ def test_out_of_order_events(self):
+ """Test that we handle persisting events that we don't have the full
+ auth chain for yet (which should only happen for out of band memberships).
+ """
+ event_factory = self.hs.get_event_builder_factory()
+ bob = "@creator:test"
+ alice = "@alice:test"
+ room_id = "!room:test"
+
+ # Ensure that we have a rooms entry so that we generate the chain index.
+ self.get_success(
+ self.store.store_room(
+ room_id=room_id,
+ room_creator_user_id="",
+ is_public=True,
+ room_version=RoomVersions.V6,
+ )
+ )
+
+ # First persist the base room.
+ create = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Create,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "create"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[])
+ )
+
+ bob_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": bob,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "bob_join"},
+ },
+ ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
+ )
+
+ power = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "power"},
+ },
+ ).build(
+ prev_event_ids=[], auth_event_ids=[create.event_id, bob_join.event_id],
+ )
+ )
+
+ self.persist([create, bob_join, power])
+
+ # Now persist an invite and a couple of memberships out of order.
+ alice_invite = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": bob,
+ "room_id": room_id,
+ "content": {"tag": "alice_invite"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
+ )
+ )
+
+ alice_join = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
+ )
+ )
+
+ alice_join2 = self.get_success(
+ event_factory.for_room_version(
+ RoomVersions.V6,
+ {
+ "type": EventTypes.Member,
+ "state_key": alice,
+ "sender": alice,
+ "room_id": room_id,
+ "content": {"tag": "alice_join2"},
+ },
+ ).build(
+ prev_event_ids=[],
+ auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
+ )
+ )
+
+ self.persist([alice_join])
+ self.persist([alice_join2])
+ self.persist([alice_invite])
+
+ # The end result should be sane.
+ events = [create, bob_join, power, alice_invite, alice_join]
+
+ chain_map, link_map = self.fetch_chains(events)
+
+ expected_links = [
+ (bob_join, create),
+ (power, create),
+ (power, bob_join),
+ (alice_invite, create),
+ (alice_invite, power),
+ (alice_invite, bob_join),
+ ]
+
+ # Check that the expected links and only the expected links have been
+ # added.
+ self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
+
+ for start, end in expected_links:
+ start_id, start_seq = chain_map[start.event_id]
+ end_id, end_seq = chain_map[end.event_id]
+
+ self.assertIn(
+ (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
+ )
+
+ def persist(
+ self, events: List[EventBase],
+ ):
+ """Persist the given events and check that the links generated match
+ those given.
+ """
+
+ persist_events_store = self.hs.get_datastores().persist_events
+
+ for e in events:
+ e.internal_metadata.stream_ordering = self._next_stream_ordering
+ self._next_stream_ordering += 1
+
+ def _persist(txn):
+ # We need to persist the events to the events and state_events
+ # tables.
+ persist_events_store._store_event_txn(txn, [(e, {}) for e in events])
+
+ # Actually call the function that calculates the auth chain stuff.
+ persist_events_store._persist_event_auth_chain_txn(txn, events)
+
+ self.get_success(
+ persist_events_store.db_pool.runInteraction("_persist", _persist,)
+ )
+
+ def fetch_chains(
+ self, events: List[EventBase]
+ ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
+
+ # Fetch the map from event ID -> (chain ID, sequence number)
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in events],
+ retcols=("event_id", "chain_id", "sequence_number"),
+ keyvalues={},
+ )
+ )
+
+ chain_map = {
+ row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
+ }
+
+ # Fetch all the links and pass them to the _LinkMap.
+ rows = self.get_success(
+ self.store.db_pool.simple_select_many_batch(
+ table="event_auth_chain_links",
+ column="origin_chain_id",
+ iterable=[chain_id for chain_id, _ in chain_map.values()],
+ retcols=(
+ "origin_chain_id",
+ "origin_sequence_number",
+ "target_chain_id",
+ "target_sequence_number",
+ ),
+ keyvalues={},
+ )
+ )
+
+ link_map = _LinkMap()
+ for row in rows:
+ added = link_map.add_link(
+ (row["origin_chain_id"], row["origin_sequence_number"]),
+ (row["target_chain_id"], row["target_sequence_number"]),
+ )
+
+ # We shouldn't have persisted any redundant links
+ self.assertTrue(added)
+
+ return chain_map, link_map
+
+
+class LinkMapTestCase(unittest.TestCase):
+ def test_simple(self):
+ """Basic tests for the LinkMap.
+ """
+ link_map = _LinkMap()
+
+ link_map.add_link((1, 1), (2, 1), new=False)
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+ self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
+ self.assertCountEqual(link_map.get_additions(), [])
+ self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
+ self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
+ self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
+
+ # Attempting to add a redundant link is ignored.
+ self.assertFalse(link_map.add_link((1, 4), (2, 1)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
+
+ # Adding new non-redundant links works
+ self.assertTrue(link_map.add_link((1, 3), (2, 3)))
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertTrue(link_map.add_link((2, 5), (1, 3)))
+ self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
+ self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
+
+ self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 482506d731..9d04a066d8 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,6 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
+from parameterized import parameterized
+
+from synapse.events import _EventInternalMetadata
+
import tests.unittest
import tests.utils
@@ -113,7 +118,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
- def test_auth_difference(self):
+ @parameterized.expand([(True,), (False,)])
+ def test_auth_difference(self, use_chain_cover_index: bool):
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -159,46 +165,223 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
+ # Mark the room as not having a cover index
+
+ def store_room(txn):
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": use_chain_cover_index,
+ },
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("store_room", store_room))
+
# We rudely fiddle with the appropriate tables directly, as that's much
# easier than constructing events properly.
- def insert_event(txn, event_id, stream_ordering):
+ def insert_event(txn):
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn,
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ ],
+ )
+
+ self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference(room_id, [{"a"}])
+ )
+ self.assertSetEqual(difference, set())
+
+ def test_auth_difference_partial_cover(self):
+ """Test that we correctly handle rooms where not all events have a chain
+ cover calculated. This can happen in some obscure edge cases, including
+ during the background update that calculates the chain cover for old
+ rooms.
+ """
+
+ room_id = "@ROOM:local"
+
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ยด |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
+
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
- depth = depth_map[event_id]
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+ def insert_event(txn):
+ # First insert the room and mark it as having a chain cover.
self.store.db_pool.simple_insert_txn(
txn,
- table="events",
- values={
- "event_id": event_id,
+ "rooms",
+ {
"room_id": room_id,
- "depth": depth,
- "topological_ordering": depth,
- "type": "m.test",
- "processed": True,
- "outlier": False,
- "stream_ordering": stream_ordering,
+ "creator": "room_creator_user_id",
+ "is_public": True,
+ "room_version": "6",
+ "has_auth_chain_index": True,
},
)
- self.store.db_pool.simple_insert_many_txn(
+ stream_ordering = 0
+
+ for event_id in auth_graph:
+ stream_ordering += 1
+ depth = depth_map[event_id]
+
+ self.store.db_pool.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ # Insert all events apart from 'B'
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
txn,
- table="event_auth",
- values=[
- {"event_id": event_id, "room_id": room_id, "auth_id": a}
- for a in auth_graph[event_id]
+ [
+ FakeEvent(event_id, room_id, auth_graph[event_id])
+ for event_id in auth_graph
+ if event_id != "b"
],
)
- next_stream_ordering = 0
- for event_id in auth_graph:
- next_stream_ordering += 1
- self.get_success(
- self.store.db_pool.runInteraction(
- "insert", insert_event, event_id, next_stream_ordering
- )
+ # Now we insert the event 'B' without a chain cover, by temporarily
+ # pretending the room doesn't have a chain cover.
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ )
+
+ self.hs.datastores.persist_events._persist_event_auth_chain_txn(
+ txn, [FakeEvent("b", room_id, auth_graph["b"])],
+ )
+
+ self.store.db_pool.simple_update_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": True},
)
+ self.get_success(self.store.db_pool.runInteraction("insert", insert_event,))
+
# Now actually test that various combinations give the right result:
difference = self.get_success(
@@ -240,3 +423,21 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.store.get_auth_chain_difference(room_id, [{"a"}])
)
self.assertSetEqual(difference, set())
+
+
+@attr.s
+class FakeEvent:
+ event_id = attr.ib()
+ room_id = attr.ib()
+ auth_events = attr.ib()
+
+ type = "foo"
+ state_key = "foo"
+
+ internal_metadata = _EventInternalMetadata({})
+
+ def auth_event_ids(self):
+ return self.auth_events
+
+ def is_state(self):
+ return True
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 0ab0a91483..1184cea5a3 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.iterutils import chunk_seq
+from typing import Dict, List
+
+from synapse.util.iterutils import chunk_seq, sorted_topologically
from tests.unittest import TestCase
@@ -45,3 +47,40 @@ class ChunkSeqTests(TestCase):
self.assertEqual(
list(parts), [],
)
+
+
+class SortTopologically(TestCase):
+ def test_empty(self):
+ "Test that an empty graph works correctly"
+
+ graph = {} # type: Dict[int, List[int]]
+ self.assertEqual(list(sorted_topologically([], graph)), [])
+
+ def test_disconnected(self):
+ "Test that a graph with no edges work"
+
+ graph = {1: [], 2: []} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
+ def test_linear(self):
+ "Test that a simple `4 -> 3 -> 2 -> 1` graph works"
+
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+
+ def test_subset(self):
+ "Test that only sorting a subset of the graph works"
+ graph = {1: [], 2: [1], 3: [2], 4: [3]} # type: Dict[int, List[int]]
+
+ self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
+
+ def test_fork(self):
+ "Test that a forked graph works"
+ graph = {1: [], 2: [1], 3: [1], 4: [2, 3]} # type: Dict[int, List[int]]
+
+ # Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
+ # always get the same one.
+ self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
|