summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/state/__init__.py28
-rw-r--r--synapse/state/v2.py32
-rw-r--r--synapse/storage/data_stores/main/event_federation.py150
3 files changed, 161 insertions, 49 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index df7a4f6a89..4afefc6b1d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -662,28 +662,16 @@ class StateResolutionStore(object):
             allow_rejected=allow_rejected,
         )
 
-    def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
-        """Gets the full auth chain for a set of events (including rejected
-        events).
-
-        Includes the given event IDs in the result.
-
-        Note that:
-            1. All events must be state events.
-            2. For v1 rooms this may not have the full auth chain in the
-               presence of rejected events
-
-        Args:
-            event_ids: The event IDs of the events to fetch the auth chain for.
-                Must be state events.
-            ignore_events: Set of events to exclude from the returned auth
-                chain.
+    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+        """Given sets of state events figure out the auth chain difference (as
+        per state res v2 algorithm).
 
+        This equivalent to fetching the full auth chain for each set of state
+        and returning the events that don't appear in each and every auth
+        chain.
 
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            Deferred[Set[str]]: Set of event IDs.
         """
 
-        return self.store.get_auth_chain_ids(
-            event_ids, include_given=True, ignore_events=ignore_events,
-        )
+        return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 0ffe6d8c14..18484e2fa6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     Returns:
         Deferred[set[str]]: Set of event IDs
     """
-    common = set(itervalues(state_sets[0])).intersection(
-        *(itervalues(s) for s in state_sets[1:])
-    )
-
-    auth_sets = []
-    for state_set in state_sets:
-        auth_ids = {
-            eid
-            for key, eid in iteritems(state_set)
-            if (
-                key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
-                or key
-                in (
-                    (EventTypes.PowerLevels, ""),
-                    (EventTypes.Create, ""),
-                    (EventTypes.JoinRules, ""),
-                )
-            )
-            and eid not in common
-        }
 
-        auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
-        auth_ids.update(auth_chain)
-
-        auth_sets.append(auth_ids)
-
-    intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
-    union = set().union(*auth_sets)
+    difference = yield state_res_store.get_auth_chain_difference(
+        [set(state_set.values()) for state_set in state_sets]
+    )
 
-    return union - intersection
+    return difference
 
 
 def _seperate(state_sets):
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 49a7b8b433..62d4e9f599 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 from six.moves.queue import Empty, PriorityQueue
 
@@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
+    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+        """Given sets of state events figure out the auth chain difference (as
+        per state res v2 algorithm).
+
+        This equivalent to fetching the full auth chain for each set of state
+        and returning the events that don't appear in each and every auth
+        chain.
+
+        Returns:
+            Deferred[Set[str]]
+        """
+
+        return self.db.runInteraction(
+            "get_auth_chain_difference",
+            self._get_auth_chain_difference_txn,
+            state_sets,
+        )
+
+    def _get_auth_chain_difference_txn(
+        self, txn, state_sets: List[Set[str]]
+    ) -> Set[str]:
+
+        # Algorithm Description
+        # ~~~~~~~~~~~~~~~~~~~~~
+        #
+        # The idea here is to basically walk the auth graph of each state set in
+        # tandem, keeping track of which auth events are reachable by each state
+        # set. If we reach an auth event we've already visited (via a different
+        # state set) then we mark that auth event and all ancestors as reachable
+        # by the state set. This requires that we keep track of the auth chains
+        # in memory.
+        #
+        # Doing it in a such a way means that we can stop early if all auth
+        # events we're currently walking are reachable by all state sets.
+        #
+        # *Note*: We can't stop walking an event's auth chain if it is reachable
+        # by all state sets. This is because other auth chains we're walking
+        # might be reachable only via the original auth chain. For example,
+        # given the following auth chain:
+        #
+        #       A -> C -> D -> E
+        #           /         /
+        #       B -´---------´
+        #
+        # and state sets {A} and {B} then walking the auth chains of A and B
+        # would immediately show that C is reachable by both. However, if we
+        # stopped at C then we'd only reach E via the auth chain of B and so E
+        # would errornously get included in the returned difference.
+        #
+        # The other thing that we do is limit the number of auth chains we walk
+        # at once, due to practical limits (i.e. we can only query the database
+        # with a limited set of parameters). We pick the auth chains we walk
+        # each iteration based on their depth, in the hope that events with a
+        # lower depth are likely reachable by those with higher depths.
+        #
+        # We could use any ordering that we believe would give a rough
+        # topological ordering, e.g. origin server timestamp. If the ordering
+        # chosen is not topological then the algorithm still produces the right
+        # result, but perhaps a bit more inefficiently. This is why it is safe
+        # to use "depth" here.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Dict from events in auth chains to which sets *cannot* reach them.
+        # I.e. if the set is empty then all sets can reach the event.
+        event_to_missing_sets = {
+            event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+            for event_id in initial_events
+        }
+
+        # We need to get the depth of the initial events for sorting purposes.
+        sql = """
+            SELECT depth, event_id FROM events
+            WHERE %s
+            ORDER BY depth ASC
+        """
+        clause, args = make_in_list_sql_clause(
+            txn.database_engine, "event_id", initial_events
+        )
+        txn.execute(sql % (clause,), args)
+
+        # The sorted list of events whose auth chains we should walk.
+        search = txn.fetchall()  # type: List[Tuple[int, str]]
+
+        # Map from event to its auth events
+        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+
+        base_sql = """
+            SELECT a.event_id, auth_id, depth
+            FROM event_auth AS a
+            INNER JOIN events AS e ON (e.event_id = a.auth_id)
+            WHERE
+        """
+
+        while search:
+            # Check whether all our current walks are reachable by all state
+            # sets. If so we can bail.
+            if all(not event_to_missing_sets[eid] for _, eid in search):
+                break
+
+            # Fetch the auth events and their depths of the N last events we're
+            # currently walking
+            search, chunk = search[:-100], search[-100:]
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+            )
+            txn.execute(base_sql + clause, args)
+
+            for event_id, auth_event_id, auth_event_depth in txn:
+                event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+                sets = event_to_missing_sets.get(auth_event_id)
+                if sets is None:
+                    # First time we're seeing this event, so we add it to the
+                    # queue of things to fetch.
+                    search.append((auth_event_depth, auth_event_id))
+
+                    # Assume that this event is unreachable from any of the
+                    # state sets until proven otherwise
+                    sets = event_to_missing_sets[auth_event_id] = set(
+                        range(len(state_sets))
+                    )
+                else:
+                    # We've previously seen this event, so look up its auth
+                    # events and recursively mark all ancestors as reachable
+                    # by the current event's state set.
+                    a_ids = event_to_auth_events.get(auth_event_id)
+                    while a_ids:
+                        new_aids = set()
+                        for a_id in a_ids:
+                            event_to_missing_sets[a_id].intersection_update(
+                                event_to_missing_sets[event_id]
+                            )
+
+                            b = event_to_auth_events.get(a_id)
+                            if b:
+                                new_aids.update(b)
+
+                        a_ids = new_aids
+
+                # Mark that the auth event is reachable by the approriate sets.
+                sets.intersection_update(event_to_missing_sets[event_id])
+
+            search.sort()
+
+        # Return all events where not all sets can reach them.
+        return {eid for eid, n in event_to_missing_sets.items() if n}
+
     def get_oldest_events_in_room(self, room_id):
         return self.db.runInteraction(
             "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id