diff --git a/changelog.d/7095.misc b/changelog.d/7095.misc
new file mode 100644
index 0000000000..44fc9f616f
--- /dev/null
+++ b/changelog.d/7095.misc
@@ -0,0 +1 @@
+Attempt to improve performance of state res v2 algorithm.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index df7a4f6a89..4afefc6b1d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -662,28 +662,16 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
- def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
- """Gets the full auth chain for a set of events (including rejected
- events).
-
- Includes the given event IDs in the result.
-
- Note that:
- 1. All events must be state events.
- 2. For v1 rooms this may not have the full auth chain in the
- presence of rejected events
-
- Args:
- event_ids: The event IDs of the events to fetch the auth chain for.
- Must be state events.
- ignore_events: Set of events to exclude from the returned auth
- chain.
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
+ This equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ Deferred[Set[str]]: Set of event IDs.
"""
- return self.store.get_auth_chain_ids(
- event_ids, include_given=True, ignore_events=ignore_events,
- )
+ return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 0ffe6d8c14..18484e2fa6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
- common = set(itervalues(state_sets[0])).intersection(
- *(itervalues(s) for s in state_sets[1:])
- )
-
- auth_sets = []
- for state_set in state_sets:
- auth_ids = {
- eid
- for key, eid in iteritems(state_set)
- if (
- key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
- or key
- in (
- (EventTypes.PowerLevels, ""),
- (EventTypes.Create, ""),
- (EventTypes.JoinRules, ""),
- )
- )
- and eid not in common
- }
- auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
- auth_ids.update(auth_chain)
-
- auth_sets.append(auth_ids)
-
- intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
- union = set().union(*auth_sets)
+ difference = yield state_res_store.get_auth_chain_difference(
+ [set(state_set.values()) for state_set in state_sets]
+ )
- return union - intersection
+ return difference
def _seperate(state_sets):
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 49a7b8b433..62d4e9f599 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,7 +14,7 @@
# limitations under the License.
import itertools
import logging
-from typing import List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
from six.moves.queue import Empty, PriorityQueue
@@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
+
+ This equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
+
+ Returns:
+ Deferred[Set[str]]
+ """
+
+ return self.db.runInteraction(
+ "get_auth_chain_difference",
+ self._get_auth_chain_difference_txn,
+ state_sets,
+ )
+
+ def _get_auth_chain_difference_txn(
+ self, txn, state_sets: List[Set[str]]
+ ) -> Set[str]:
+
+ # Algorithm Description
+ # ~~~~~~~~~~~~~~~~~~~~~
+ #
+ # The idea here is to basically walk the auth graph of each state set in
+ # tandem, keeping track of which auth events are reachable by each state
+ # set. If we reach an auth event we've already visited (via a different
+ # state set) then we mark that auth event and all ancestors as reachable
+ # by the state set. This requires that we keep track of the auth chains
+ # in memory.
+ #
+ # Doing it in a such a way means that we can stop early if all auth
+ # events we're currently walking are reachable by all state sets.
+ #
+ # *Note*: We can't stop walking an event's auth chain if it is reachable
+ # by all state sets. This is because other auth chains we're walking
+ # might be reachable only via the original auth chain. For example,
+ # given the following auth chain:
+ #
+ # A -> C -> D -> E
+ # / /
+ # B -´---------´
+ #
+ # and state sets {A} and {B} then walking the auth chains of A and B
+ # would immediately show that C is reachable by both. However, if we
+ # stopped at C then we'd only reach E via the auth chain of B and so E
+ # would errornously get included in the returned difference.
+ #
+ # The other thing that we do is limit the number of auth chains we walk
+ # at once, due to practical limits (i.e. we can only query the database
+ # with a limited set of parameters). We pick the auth chains we walk
+ # each iteration based on their depth, in the hope that events with a
+ # lower depth are likely reachable by those with higher depths.
+ #
+ # We could use any ordering that we believe would give a rough
+ # topological ordering, e.g. origin server timestamp. If the ordering
+ # chosen is not topological then the algorithm still produces the right
+ # result, but perhaps a bit more inefficiently. This is why it is safe
+ # to use "depth" here.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Dict from events in auth chains to which sets *cannot* reach them.
+ # I.e. if the set is empty then all sets can reach the event.
+ event_to_missing_sets = {
+ event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+ for event_id in initial_events
+ }
+
+ # We need to get the depth of the initial events for sorting purposes.
+ sql = """
+ SELECT depth, event_id FROM events
+ WHERE %s
+ ORDER BY depth ASC
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", initial_events
+ )
+ txn.execute(sql % (clause,), args)
+
+ # The sorted list of events whose auth chains we should walk.
+ search = txn.fetchall() # type: List[Tuple[int, str]]
+
+ # Map from event to its auth events
+ event_to_auth_events = {} # type: Dict[str, Set[str]]
+
+ base_sql = """
+ SELECT a.event_id, auth_id, depth
+ FROM event_auth AS a
+ INNER JOIN events AS e ON (e.event_id = a.auth_id)
+ WHERE
+ """
+
+ while search:
+ # Check whether all our current walks are reachable by all state
+ # sets. If so we can bail.
+ if all(not event_to_missing_sets[eid] for _, eid in search):
+ break
+
+ # Fetch the auth events and their depths of the N last events we're
+ # currently walking
+ search, chunk = search[:-100], search[-100:]
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+ )
+ txn.execute(base_sql + clause, args)
+
+ for event_id, auth_event_id, auth_event_depth in txn:
+ event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+ sets = event_to_missing_sets.get(auth_event_id)
+ if sets is None:
+ # First time we're seeing this event, so we add it to the
+ # queue of things to fetch.
+ search.append((auth_event_depth, auth_event_id))
+
+ # Assume that this event is unreachable from any of the
+ # state sets until proven otherwise
+ sets = event_to_missing_sets[auth_event_id] = set(
+ range(len(state_sets))
+ )
+ else:
+ # We've previously seen this event, so look up its auth
+ # events and recursively mark all ancestors as reachable
+ # by the current event's state set.
+ a_ids = event_to_auth_events.get(auth_event_id)
+ while a_ids:
+ new_aids = set()
+ for a_id in a_ids:
+ event_to_missing_sets[a_id].intersection_update(
+ event_to_missing_sets[event_id]
+ )
+
+ b = event_to_auth_events.get(a_id)
+ if b:
+ new_aids.update(b)
+
+ a_ids = new_aids
+
+ # Mark that the auth event is reachable by the approriate sets.
+ sets.intersection_update(event_to_missing_sets[event_id])
+
+ search.sort()
+
+ # Return all events where not all sets can reach them.
+ return {eid for eid, n in event_to_missing_sets.items() if n}
+
def get_oldest_events_in_room(self, room_id):
return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 5059ade850..a44960203e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
- def get_auth_chain(self, event_ids, ignore_events):
+ def _get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -617,9 +617,6 @@ class TestStateResolutionStore(object):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
- ignore_events: Set of events to exclude from the returned auth
- chain.
-
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
@@ -629,7 +626,7 @@ class TestStateResolutionStore(object):
stack = list(event_ids)
while stack:
event_id = stack.pop()
- if event_id in result or event_id in ignore_events:
+ if event_id in result:
continue
result.add(event_id)
@@ -639,3 +636,9 @@ class TestStateResolutionStore(object):
stack.append(aid)
return list(result)
+
+ def get_auth_chain_difference(self, auth_sets):
+ chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
+
+ common = set(chains[0]).intersection(*chains[1:])
+ return set(chains[0]).union(*chains[1:]) - common
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index a331517f4d..3aeec0dc0f 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,19 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
import tests.unittest
import tests.utils
-class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
- @defer.inlineCallbacks
- def setUp(self):
- hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+ def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
def test_get_prev_events_for_room(self):
room_id = "@ROOM:local"
@@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 20):
- yield self.store.db.runInteraction("insert", insert_event, i)
+ self.get_success(self.store.db.runInteraction("insert", insert_event, i))
# this should get the last ten
- r = yield self.store.get_prev_events_for_room(room_id)
+ r = self.get_success(self.store.get_prev_events_for_room(room_id))
self.assertEqual(10, len(r))
for i in range(0, 10):
self.assertEqual("$event_%i:local" % (19 - i), r[i])
- @defer.inlineCallbacks
def test_get_rooms_with_many_extremities(self):
room1 = "#room1"
room2 = "#room2"
@@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
)
for i in range(0, 20):
- yield self.store.db.runInteraction("insert", insert_event, i, room1)
- yield self.store.db.runInteraction("insert", insert_event, i, room2)
- yield self.store.db.runInteraction("insert", insert_event, i, room3)
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room1)
+ )
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room2)
+ )
+ self.get_success(
+ self.store.db.runInteraction("insert", insert_event, i, room3)
+ )
# Test simple case
- r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
self.assertEqual(len(r), 3)
# Does filter work?
- r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1])
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
self.assertTrue(room2 in r)
self.assertTrue(room3 in r)
self.assertEqual(len(r), 2)
- r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+ r = self.get_success(
+ self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+ )
self.assertEqual(r, [room3])
# Does filter and limit work?
- r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1])
+ r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
+
+ def test_auth_difference(self):
+ room_id = "@ROOM:local"
+
+ # The silly auth graph we use to test the auth difference algorithm,
+ # where the top are the most recent events.
+ #
+ # A B
+ # \ /
+ # D E
+ # \ |
+ # ` F C
+ # | /|
+ # G ´ |
+ # | \ |
+ # H I
+ # | |
+ # K J
+
+ auth_graph = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+ }
+
+ depth_map = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+ }
+
+ # We rudely fiddle with the appropriate tables directly, as that's much
+ # easier than constructing events properly.
+
+ def insert_event(txn, event_id, stream_ordering):
+
+ depth = depth_map[event_id]
+
+ self.store.db.simple_insert_txn(
+ txn,
+ table="events",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "depth": depth,
+ "topological_ordering": depth,
+ "type": "m.test",
+ "processed": True,
+ "outlier": False,
+ "stream_ordering": stream_ordering,
+ },
+ )
+
+ self.store.db.simple_insert_many_txn(
+ txn,
+ table="event_auth",
+ values=[
+ {"event_id": event_id, "room_id": room_id, "auth_id": a}
+ for a in auth_graph[event_id]
+ ],
+ )
+
+ next_stream_ordering = 0
+ for event_id in auth_graph:
+ next_stream_ordering += 1
+ self.get_success(
+ self.store.db.runInteraction(
+ "insert", insert_event, event_id, next_stream_ordering
+ )
+ )
+
+ # Now actually test that various combinations give the right result:
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+ )
+ self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+ difference = self.get_success(
+ self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+ )
+ self.assertSetEqual(difference, {"a", "b"})
+
+ difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+ self.assertSetEqual(difference, set())
|