summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2018-10-24 11:12:12 +0100
committerGitHub <noreply@github.com>2018-10-24 11:12:12 +0100
commit3904cbf30761e7cbeea633467262837c62ef5829 (patch)
treefbd67b9770bae377e2a6427d6023523595ede1f2
parentMerge pull request #4075 from matrix-org/rav/fix_pusher_logcontexts (diff)
parentisort (diff)
downloadsynapse-3904cbf30761e7cbeea633467262837c62ef5829.tar.xz
Merge pull request #4040 from matrix-org/erikj/states_res_v2_rebase
Add v2 state resolution algorithm
Diffstat (limited to '')
-rw-r--r--changelog.d/3786.misc1
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/handlers/federation.py30
-rw-r--r--synapse/state/__init__.py97
-rw-r--r--synapse/state/v1.py2
-rw-r--r--synapse/state/v2.py544
-rw-r--r--synapse/storage/events.py45
-rw-r--r--tests/state/__init__.py0
-rw-r--r--tests/state/test_v2.py663
9 files changed, 1323 insertions, 61 deletions
diff --git a/changelog.d/3786.misc b/changelog.d/3786.misc
new file mode 100644
index 0000000000..a9f9a2bb27
--- /dev/null
+++ b/changelog.d/3786.misc
@@ -0,0 +1 @@
+Add initial implementation of new state resolution algorithm
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index af3eee95b9..d4d4474847 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -690,7 +690,7 @@ def auth_types_for_event(event):
     auth_types = []
 
     auth_types.append((EventTypes.PowerLevels, "", ))
-    auth_types.append((EventTypes.Member, event.user_id, ))
+    auth_types.append((EventTypes.Member, event.sender, ))
     auth_types.append((EventTypes.Create, "", ))
 
     if event.type == EventTypes.Member:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 63e495e3f8..cd5b9bbb19 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -53,7 +53,7 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
-from synapse.state import resolve_events_with_factory
+from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.types import UserID, get_domain_from_id
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
@@ -384,24 +384,24 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    # Resolve any conflicting state
-                    @defer.inlineCallbacks
-                    def fetch(ev_ids):
-                        fetched = yield self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False,
-                        )
-                        # add any events we fetch here to the `event_map` so that we
-                        # can use them to build the state event list below.
-                        event_map.update(fetched)
-                        defer.returnValue(fetched)
-
                     room_version = yield self.store.get_room_version(room_id)
-                    state_map = yield resolve_events_with_factory(
-                        room_version, state_maps, event_map, fetch,
+                    state_map = yield resolve_events_with_store(
+                        room_version, state_maps, event_map,
+                        state_res_store=StateResolutionStore(self.store),
                     )
 
-                    # we need to give _process_received_pdu the actual state events
+                    # We need to give _process_received_pdu the actual state events
                     # rather than event ids, so generate that now.
+
+                    # First though we need to fetch all the events that are in
+                    # state_map, so we can build up the state below.
+                    evs = yield self.store.get_events(
+                        list(state_map.values()),
+                        get_prev_content=False,
+                        check_redacted=False,
+                    )
+                    event_map.update(evs)
+
                     state = [
                         event_map[e] for e in six.itervalues(state_map)
                     ]
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index b22495c1f9..9b40b18d5b 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -19,13 +19,14 @@ from collections import namedtuple
 
 from six import iteritems, itervalues
 
+import attr
 from frozendict import frozendict
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, RoomVersions
 from synapse.events.snapshot import EventContext
-from synapse.state import v1
+from synapse.state import v1, v2
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -372,15 +373,10 @@ class StateHandler(object):
 
         result = yield self._state_resolution_handler.resolve_state_groups(
             room_id, room_version, state_groups_ids, None,
-            self._state_map_factory,
+            state_res_store=StateResolutionStore(self.store),
         )
         defer.returnValue(result)
 
-    def _state_map_factory(self, ev_ids):
-        return self.store.get_events(
-            ev_ids, get_prev_content=False, check_redacted=False,
-        )
-
     @defer.inlineCallbacks
     def resolve_events(self, room_version, state_sets, event):
         logger.info(
@@ -398,10 +394,10 @@ class StateHandler(object):
         }
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = yield resolve_events_with_factory(
+            new_state = yield resolve_events_with_store(
                 room_version, state_set_ids,
                 event_map=state_map,
-                state_map_factory=self._state_map_factory
+                state_res_store=StateResolutionStore(self.store),
             )
 
         new_state = {
@@ -436,7 +432,7 @@ class StateResolutionHandler(object):
     @defer.inlineCallbacks
     @log_function
     def resolve_state_groups(
-        self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
+        self, room_id, room_version, state_groups_ids, event_map, state_res_store,
     ):
         """Resolves conflicts between a set of state groups
 
@@ -454,9 +450,11 @@ class StateResolutionHandler(object):
                 a dict from event_id to event, for any events that we happen to
                 have in flight (eg, those currently being persisted). This will be
                 used as a starting point fof finding the state we need; any missing
-                events will be requested via state_map_factory.
+                events will be requested via state_res_store.
+
+                If None, all events will be fetched via state_res_store.
 
-                If None, all events will be fetched via state_map_factory.
+            state_res_store (StateResolutionStore)
 
         Returns:
             Deferred[_StateCacheEntry]: resolved state
@@ -480,10 +478,10 @@ class StateResolutionHandler(object):
 
             # start by assuming we won't have any conflicted state, and build up the new
             # state map by iterating through the state groups. If we discover a conflict,
-            # we give up and instead use `resolve_events_with_factory`.
+            # we give up and instead use `resolve_events_with_store`.
             #
             # XXX: is this actually worthwhile, or should we just let
-            # resolve_events_with_factory do it?
+            # resolve_events_with_store do it?
             new_state = {}
             conflicted_state = False
             for st in itervalues(state_groups_ids):
@@ -498,11 +496,11 @@ class StateResolutionHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events_with_factory(
+                    new_state = yield resolve_events_with_store(
                         room_version,
                         list(itervalues(state_groups_ids)),
                         event_map=event_map,
-                        state_map_factory=state_map_factory,
+                        state_res_store=state_res_store,
                     )
 
             # if the new state matches any of the input state groups, we can
@@ -583,7 +581,7 @@ def _make_state_cache_entry(
     )
 
 
-def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
+def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
     """
     Args:
         room_version(str): Version of the room
@@ -599,17 +597,19 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f
 
             If None, all events will be fetched via state_map_factory.
 
-        state_map_factory(func): will be called
-            with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event.
+        state_res_store (StateResolutionStore)
 
     Returns
         Deferred[dict[(str, str), str]]:
             a map from (type, state_key) to event_id.
     """
-    if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
-        return v1.resolve_events_with_factory(
-            state_sets, event_map, state_map_factory,
+    if room_version == RoomVersions.V1:
+        return v1.resolve_events_with_store(
+            state_sets, event_map, state_res_store.get_events,
+        )
+    elif room_version == RoomVersions.VDH_TEST:
+        return v2.resolve_events_with_store(
+            state_sets, event_map, state_res_store,
         )
     else:
         # This should only happen if we added a version but forgot to add it to
@@ -617,3 +617,54 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f
         raise Exception(
             "No state resolution algorithm defined for version %r" % (room_version,)
         )
+
+
+@attr.s
+class StateResolutionStore(object):
+    """Interface that allows state resolution algorithms to access the database
+    in well defined way.
+
+    Args:
+        store (DataStore)
+    """
+
+    store = attr.ib()
+
+    def get_events(self, event_ids, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+        """
+
+        return self.store.get_events(
+            event_ids,
+            check_redacted=False,
+            get_prev_content=False,
+            allow_rejected=allow_rejected,
+        )
+
+    def get_auth_chain(self, event_ids):
+        """Gets the full auth chain for a set of events (including rejected
+        events).
+
+        Includes the given event IDs in the result.
+
+        Note that:
+            1. All events must be state events.
+            2. For v1 rooms this may not have the full auth chain in the
+               presence of rejected events
+
+        Args:
+            event_ids (list): The event IDs of the events to fetch the auth
+                chain for. Must be state events.
+
+        Returns:
+            Deferred[list[str]]: List of event IDs of the auth chain.
+        """
+
+        return self.store.get_auth_chain_ids(event_ids, include_given=True)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 7a7157b352..70a981f4a2 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -31,7 +31,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 
 @defer.inlineCallbacks
-def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+def resolve_events_with_store(state_sets, event_map, state_map_factory):
     """
     Args:
         state_sets(list): List of dicts of (type, state_key) -> event_id,
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
new file mode 100644
index 0000000000..5d06f7e928
--- /dev/null
+++ b/synapse/state/v2.py
@@ -0,0 +1,544 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import heapq
+import itertools
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
+from synapse import event_auth
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def resolve_events_with_store(state_sets, event_map, state_res_store):
+    """Resolves the state using the v2 state resolution algorithm
+
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+
+        event_map(dict[str,FrozenEvent]|None):
+            a dict from event_id to event, for any events that we happen to
+            have in flight (eg, those currently being persisted). This will be
+            used as a starting point fof finding the state we need; any missing
+            events will be requested via state_res_store.
+
+            If None, all events will be fetched via state_res_store.
+
+        state_res_store (StateResolutionStore)
+
+    Returns
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
+    """
+
+    logger.debug("Computing conflicted state")
+
+    # First split up the un/conflicted state
+    unconflicted_state, conflicted_state = _seperate(state_sets)
+
+    if not conflicted_state:
+        defer.returnValue(unconflicted_state)
+
+    logger.debug("%d conflicted state entries", len(conflicted_state))
+    logger.debug("Calculating auth chain difference")
+
+    # Also fetch all auth events that appear in only some of the state sets'
+    # auth chains.
+    auth_diff = yield _get_auth_chain_difference(
+        state_sets, event_map, state_res_store,
+    )
+
+    full_conflicted_set = set(itertools.chain(
+        itertools.chain.from_iterable(itervalues(conflicted_state)),
+        auth_diff,
+    ))
+
+    events = yield state_res_store.get_events([
+        eid for eid in full_conflicted_set
+        if eid not in event_map
+    ], allow_rejected=True)
+    event_map.update(events)
+
+    full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+
+    logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
+
+    # Get and sort all the power events (kicks/bans/etc)
+    power_events = (
+        eid for eid in full_conflicted_set
+        if _is_power_event(event_map[eid])
+    )
+
+    sorted_power_events = yield _reverse_topological_power_sort(
+        power_events,
+        event_map,
+        state_res_store,
+        full_conflicted_set,
+    )
+
+    logger.debug("sorted %d power events", len(sorted_power_events))
+
+    # Now sequentially auth each one
+    resolved_state = yield _iterative_auth_checks(
+        sorted_power_events, unconflicted_state, event_map,
+        state_res_store,
+    )
+
+    logger.debug("resolved power events")
+
+    # OK, so we've now resolved the power events. Now sort the remaining
+    # events using the mainline of the resolved power level.
+
+    leftover_events = [
+        ev_id
+        for ev_id in full_conflicted_set
+        if ev_id not in sorted_power_events
+    ]
+
+    logger.debug("sorting %d remaining events", len(leftover_events))
+
+    pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
+    leftover_events = yield _mainline_sort(
+        leftover_events, pl, event_map, state_res_store,
+    )
+
+    logger.debug("resolving remaining events")
+
+    resolved_state = yield _iterative_auth_checks(
+        leftover_events, resolved_state, event_map,
+        state_res_store,
+    )
+
+    logger.debug("resolved")
+
+    # We make sure that unconflicted state always still applies.
+    resolved_state.update(unconflicted_state)
+
+    logger.debug("done")
+
+    defer.returnValue(resolved_state)
+
+
+@defer.inlineCallbacks
+def _get_power_level_for_sender(event_id, event_map, state_res_store):
+    """Return the power level of the sender of the given event according to
+    their auth events.
+
+    Args:
+        event_id (str)
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+
+    Returns:
+        Deferred[int]
+    """
+    event = yield _get_event(event_id, event_map, state_res_store)
+
+    pl = None
+    for aid, _ in event.auth_events:
+        aev = yield _get_event(aid, event_map, state_res_store)
+        if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+            pl = aev
+            break
+
+    if pl is None:
+        # Couldn't find power level. Check if they're the creator of the room
+        for aid, _ in event.auth_events:
+            aev = yield _get_event(aid, event_map, state_res_store)
+            if (aev.type, aev.state_key) == (EventTypes.Create, ""):
+                if aev.content.get("creator") == event.sender:
+                    defer.returnValue(100)
+                break
+        defer.returnValue(0)
+
+    level = pl.content.get("users", {}).get(event.sender)
+    if level is None:
+        level = pl.content.get("users_default", 0)
+
+    if level is None:
+        defer.returnValue(0)
+    else:
+        defer.returnValue(int(level))
+
+
+@defer.inlineCallbacks
+def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+    """Compare the auth chains of each state set and return the set of events
+    that only appear in some but not all of the auth chains.
+
+    Args:
+        state_sets (list)
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+
+    Returns:
+        Deferred[set[str]]: Set of event IDs
+    """
+    common = set(itervalues(state_sets[0])).intersection(
+        *(itervalues(s) for s in state_sets[1:])
+    )
+
+    auth_sets = []
+    for state_set in state_sets:
+        auth_ids = set(
+            eid
+            for key, eid in iteritems(state_set)
+            if (key[0] in (
+                EventTypes.Member,
+                EventTypes.ThirdPartyInvite,
+            ) or key in (
+                (EventTypes.PowerLevels, ''),
+                (EventTypes.Create, ''),
+                (EventTypes.JoinRules, ''),
+            )) and eid not in common
+        )
+
+        auth_chain = yield state_res_store.get_auth_chain(auth_ids)
+        auth_ids.update(auth_chain)
+
+        auth_sets.append(auth_ids)
+
+    intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
+    union = set().union(*auth_sets)
+
+    defer.returnValue(union - intersection)
+
+
+def _seperate(state_sets):
+    """Return the unconflicted and conflicted state. This is different than in
+    the original algorithm, as this defines a key to be conflicted if one of
+    the state sets doesn't have that key.
+
+    Args:
+        state_sets (list)
+
+    Returns:
+        tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
+        conflicted state dict is a map from type/state_key to set of event IDs
+    """
+    unconflicted_state = {}
+    conflicted_state = {}
+
+    for key in set(itertools.chain.from_iterable(state_sets)):
+        event_ids = set(state_set.get(key) for state_set in state_sets)
+        if len(event_ids) == 1:
+            unconflicted_state[key] = event_ids.pop()
+        else:
+            event_ids.discard(None)
+            conflicted_state[key] = event_ids
+
+    return unconflicted_state, conflicted_state
+
+
+def _is_power_event(event):
+    """Return whether or not the event is a "power event", as defined by the
+    v2 state resolution algorithm
+
+    Args:
+        event (FrozenEvent)
+
+    Returns:
+        boolean
+    """
+    if (event.type, event.state_key) in (
+        (EventTypes.PowerLevels, ""),
+        (EventTypes.JoinRules, ""),
+        (EventTypes.Create, ""),
+    ):
+        return True
+
+    if event.type == EventTypes.Member:
+        if event.membership in ('leave', 'ban'):
+            return event.sender != event.state_key
+
+    return False
+
+
+@defer.inlineCallbacks
+def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
+                                       state_res_store, auth_diff):
+    """Helper function for _reverse_topological_power_sort that add the event
+    and its auth chain (that is in the auth diff) to the graph
+
+    Args:
+        graph (dict[str, set[str]]): A map from event ID to the events auth
+            event IDs
+        event_id (str): Event to add to the graph
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+    """
+
+    state = [event_id]
+    while state:
+        eid = state.pop()
+        graph.setdefault(eid, set())
+
+        event = yield _get_event(eid, event_map, state_res_store)
+        for aid, _ in event.auth_events:
+            if aid in auth_diff:
+                if aid not in graph:
+                    state.append(aid)
+
+                graph.setdefault(eid, set()).add(aid)
+
+
+@defer.inlineCallbacks
+def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
+    """Returns a list of the event_ids sorted by reverse topological ordering,
+    and then by power level and origin_server_ts
+
+    Args:
+        event_ids (list[str]): The events to sort
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+
+    Returns:
+        Deferred[list[str]]: The sorted list
+    """
+
+    graph = {}
+    for event_id in event_ids:
+        yield _add_event_and_auth_chain_to_graph(
+            graph, event_id, event_map, state_res_store, auth_diff,
+        )
+
+    event_to_pl = {}
+    for event_id in graph:
+        pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
+        event_to_pl[event_id] = pl
+
+    def _get_power_order(event_id):
+        ev = event_map[event_id]
+        pl = event_to_pl[event_id]
+
+        return -pl, ev.origin_server_ts, event_id
+
+    # Note: graph is modified during the sort
+    it = lexicographical_topological_sort(
+        graph,
+        key=_get_power_order,
+    )
+    sorted_events = list(it)
+
+    defer.returnValue(sorted_events)
+
+
+@defer.inlineCallbacks
+def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
+    """Sequentially apply auth checks to each event in given list, updating the
+    state as it goes along.
+
+    Args:
+        event_ids (list[str]): Ordered list of events to apply auth checks to
+        base_state (dict[tuple[str, str], str]): The set of state to start with
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+
+    Returns:
+        Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+    """
+    resolved_state = base_state.copy()
+
+    for event_id in event_ids:
+        event = event_map[event_id]
+
+        auth_events = {}
+        for aid, _ in event.auth_events:
+            ev = yield _get_event(aid, event_map, state_res_store)
+
+            if ev.rejected_reason is None:
+                auth_events[(ev.type, ev.state_key)] = ev
+
+        for key in event_auth.auth_types_for_event(event):
+            if key in resolved_state:
+                ev_id = resolved_state[key]
+                ev = yield _get_event(ev_id, event_map, state_res_store)
+
+                if ev.rejected_reason is None:
+                    auth_events[key] = event_map[ev_id]
+
+        try:
+            event_auth.check(
+                event, auth_events,
+                do_sig_check=False,
+                do_size_check=False
+            )
+
+            resolved_state[(event.type, event.state_key)] = event_id
+        except AuthError:
+            pass
+
+    defer.returnValue(resolved_state)
+
+
+@defer.inlineCallbacks
+def _mainline_sort(event_ids, resolved_power_event_id, event_map,
+                   state_res_store):
+    """Returns a sorted list of event_ids sorted by mainline ordering based on
+    the given event resolved_power_event_id
+
+    Args:
+        event_ids (list[str]): Events to sort
+        resolved_power_event_id (str): The final resolved power level event ID
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+
+    Returns:
+        Deferred[list[str]]: The sorted list
+    """
+    mainline = []
+    pl = resolved_power_event_id
+    while pl:
+        mainline.append(pl)
+        pl_ev = yield _get_event(pl, event_map, state_res_store)
+        auth_events = pl_ev.auth_events
+        pl = None
+        for aid, _ in auth_events:
+            ev = yield _get_event(aid, event_map, state_res_store)
+            if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
+                pl = aid
+                break
+
+    mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
+
+    event_ids = list(event_ids)
+
+    order_map = {}
+    for ev_id in event_ids:
+        depth = yield _get_mainline_depth_for_event(
+            event_map[ev_id], mainline_map,
+            event_map, state_res_store,
+        )
+        order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
+
+    event_ids.sort(key=lambda ev_id: order_map[ev_id])
+
+    defer.returnValue(event_ids)
+
+
+@defer.inlineCallbacks
+def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+    """Get the mainline depths for the given event based on the mainline map
+
+    Args:
+        event (FrozenEvent)
+        mainline_map (dict[str, int]): Map from event_id to mainline depth for
+            events in the mainline.
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+
+    Returns:
+        Deferred[int]
+    """
+
+    # We do an iterative search, replacing `event with the power level in its
+    # auth events (if any)
+    while event:
+        depth = mainline_map.get(event.event_id)
+        if depth is not None:
+            defer.returnValue(depth)
+
+        auth_events = event.auth_events
+        event = None
+
+        for aid, _ in auth_events:
+            aev = yield _get_event(aid, event_map, state_res_store)
+            if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+                event = aev
+                break
+
+    # Didn't find a power level auth event, so we just return 0
+    defer.returnValue(0)
+
+
+@defer.inlineCallbacks
+def _get_event(event_id, event_map, state_res_store):
+    """Helper function to look up event in event_map, falling back to looking
+    it up in the store
+
+    Args:
+        event_id (str)
+        event_map (dict[str,FrozenEvent])
+        state_res_store (StateResolutionStore)
+
+    Returns:
+        Deferred[FrozenEvent]
+    """
+    if event_id not in event_map:
+        events = yield state_res_store.get_events([event_id], allow_rejected=True)
+        event_map.update(events)
+    defer.returnValue(event_map[event_id])
+
+
+def lexicographical_topological_sort(graph, key):
+    """Performs a lexicographic reverse topological sort on the graph.
+
+    This returns a reverse topological sort (i.e. if node A references B then B
+    appears before A in the sort), with ties broken lexicographically based on
+    return value of the `key` function.
+
+    NOTE: `graph` is modified during the sort.
+
+    Args:
+        graph (dict[str, set[str]]): A representation of the graph where each
+            node is a key in the dict and its value are the nodes edges.
+        key (func): A function that takes a node and returns a value that is
+            comparable and used to order nodes
+
+    Yields:
+        str: The next node in the topological sort
+    """
+
+    # Note, this is basically Kahn's algorithm except we look at nodes with no
+    # outgoing edges, c.f.
+    # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
+    outdegree_map = graph
+    reverse_graph = {}
+
+    # Lists of nodes with zero out degree. Is actually a tuple of
+    # `(key(node), node)` so that sorting does the right thing
+    zero_outdegree = []
+
+    for node, edges in iteritems(graph):
+        if len(edges) == 0:
+            zero_outdegree.append((key(node), node))
+
+        reverse_graph.setdefault(node, set())
+        for edge in edges:
+            reverse_graph.setdefault(edge, set()).add(node)
+
+    # heapq is a built in implementation of a sorted queue.
+    heapq.heapify(zero_outdegree)
+
+    while zero_outdegree:
+        _, node = heapq.heappop(zero_outdegree)
+
+        for parent in reverse_graph[node]:
+            out = outdegree_map[parent]
+            out.discard(node)
+            if len(out) == 0:
+                heapq.heappush(zero_outdegree, (key(parent), parent))
+
+        yield node
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 03cedf3a75..c780f55277 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,6 +34,7 @@ from synapse.api.errors import SynapseError
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
 from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.events_worker import EventsWorkerStore
@@ -731,11 +732,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
 
         # Ok, we need to defer to the state handler to resolve our state sets.
 
-        def get_events(ev_ids):
-            return self.get_events(
-                ev_ids, get_prev_content=False, check_redacted=False,
-            )
-
         state_groups = {
             sg: state_groups_map[sg] for sg in new_state_groups
         }
@@ -745,7 +741,8 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
 
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, room_version, state_groups, events_map, get_events
+            room_id, room_version, state_groups, events_map,
+            state_res_store=StateResolutionStore(self)
         )
 
         defer.returnValue((res.state, None))
@@ -854,6 +851,27 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
+        # We want to store event_auth mappings for rejected events, as they're
+        # used in state res v2.
+        # This is only necessary if the rejected event appears in an accepted
+        # event's auth chain, but its easier for now just to store them (and
+        # it doesn't take much storage compared to storing the entire event
+        # anyway).
+        self._simple_insert_many_txn(
+            txn,
+            table="event_auth",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "auth_id": auth_id,
+                }
+                for event, _ in events_and_contexts
+                for auth_id, _ in event.auth_events
+                if event.is_state()
+            ],
+        )
+
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
         events_and_contexts = self._store_rejected_events_txn(
@@ -1329,21 +1347,6 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                     txn, event.room_id, event.redacts
                 )
 
-        self._simple_insert_many_txn(
-            txn,
-            table="event_auth",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "room_id": event.room_id,
-                    "auth_id": auth_id,
-                }
-                for event, _ in events_and_contexts
-                for auth_id, _ in event.auth_events
-                if event.is_state()
-            ],
-        )
-
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
diff --git a/tests/state/__init__.py b/tests/state/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/state/__init__.py
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
new file mode 100644
index 0000000000..efd85ebe6c
--- /dev/null
+++ b/tests/state/test_v2.py
@@ -0,0 +1,663 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+from six.moves import zip
+
+import attr
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.event_auth import auth_types_for_event
+from synapse.events import FrozenEvent
+from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.types import EventID
+
+from tests import unittest
+
+ALICE = "@alice:example.com"
+BOB = "@bob:example.com"
+CHARLIE = "@charlie:example.com"
+EVELYN = "@evelyn:example.com"
+ZARA = "@zara:example.com"
+
+ROOM_ID = "!test:example.com"
+
+MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
+MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
+
+
+ORIGIN_SERVER_TS = 0
+
+
+class FakeEvent(object):
+    """A fake event we use as a convenience.
+
+    NOTE: Again as a convenience we use "node_ids" rather than event_ids to
+    refer to events. The event_id has node_id as localpart and example.com
+    as domain.
+    """
+    def __init__(self, id, sender, type, state_key, content):
+        self.node_id = id
+        self.event_id = EventID(id, "example.com").to_string()
+        self.sender = sender
+        self.type = type
+        self.state_key = state_key
+        self.content = content
+
+    def to_event(self, auth_events, prev_events):
+        """Given the auth_events and prev_events, convert to a Frozen Event
+
+        Args:
+            auth_events (list[str]): list of event_ids
+            prev_events (list[str]): list of event_ids
+
+        Returns:
+            FrozenEvent
+        """
+        global ORIGIN_SERVER_TS
+
+        ts = ORIGIN_SERVER_TS
+        ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1
+
+        event_dict = {
+            "auth_events": [(a, {}) for a in auth_events],
+            "prev_events": [(p, {}) for p in prev_events],
+            "event_id": self.node_id,
+            "sender": self.sender,
+            "type": self.type,
+            "content": self.content,
+            "origin_server_ts": ts,
+            "room_id": ROOM_ID,
+        }
+
+        if self.state_key is not None:
+            event_dict["state_key"] = self.state_key
+
+        return FrozenEvent(event_dict)
+
+
+# All graphs start with this set of events
+INITIAL_EVENTS = [
+    FakeEvent(
+        id="CREATE",
+        sender=ALICE,
+        type=EventTypes.Create,
+        state_key="",
+        content={"creator": ALICE},
+    ),
+    FakeEvent(
+        id="IMA",
+        sender=ALICE,
+        type=EventTypes.Member,
+        state_key=ALICE,
+        content=MEMBERSHIP_CONTENT_JOIN,
+    ),
+    FakeEvent(
+        id="IPOWER",
+        sender=ALICE,
+        type=EventTypes.PowerLevels,
+        state_key="",
+        content={"users": {ALICE: 100}},
+    ),
+    FakeEvent(
+        id="IJR",
+        sender=ALICE,
+        type=EventTypes.JoinRules,
+        state_key="",
+        content={"join_rule": JoinRules.PUBLIC},
+    ),
+    FakeEvent(
+        id="IMB",
+        sender=BOB,
+        type=EventTypes.Member,
+        state_key=BOB,
+        content=MEMBERSHIP_CONTENT_JOIN,
+    ),
+    FakeEvent(
+        id="IMC",
+        sender=CHARLIE,
+        type=EventTypes.Member,
+        state_key=CHARLIE,
+        content=MEMBERSHIP_CONTENT_JOIN,
+    ),
+    FakeEvent(
+        id="IMZ",
+        sender=ZARA,
+        type=EventTypes.Member,
+        state_key=ZARA,
+        content=MEMBERSHIP_CONTENT_JOIN,
+    ),
+    FakeEvent(
+        id="START",
+        sender=ZARA,
+        type=EventTypes.Message,
+        state_key=None,
+        content={},
+    ),
+    FakeEvent(
+        id="END",
+        sender=ZARA,
+        type=EventTypes.Message,
+        state_key=None,
+        content={},
+    ),
+]
+
+INITIAL_EDGES = [
+    "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
+]
+
+
+class StateTestCase(unittest.TestCase):
+    def test_ban_vs_pl(self):
+        events = [
+            FakeEvent(
+                id="PA",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    }
+                },
+            ),
+            FakeEvent(
+                id="MA",
+                sender=ALICE,
+                type=EventTypes.Member,
+                state_key=ALICE,
+                content={"membership": Membership.JOIN},
+            ),
+            FakeEvent(
+                id="MB",
+                sender=ALICE,
+                type=EventTypes.Member,
+                state_key=BOB,
+                content={"membership": Membership.BAN},
+            ),
+            FakeEvent(
+                id="PB",
+                sender=BOB,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    },
+                },
+            ),
+        ]
+
+        edges = [
+            ["END", "MB", "MA", "PA", "START"],
+            ["END", "PB", "PA"],
+        ]
+
+        expected_state_ids = ["PA", "MA", "MB"]
+
+        self.do_check(events, edges, expected_state_ids)
+
+    def test_join_rule_evasion(self):
+        events = [
+            FakeEvent(
+                id="JR",
+                sender=ALICE,
+                type=EventTypes.JoinRules,
+                state_key="",
+                content={"join_rules": JoinRules.PRIVATE},
+            ),
+            FakeEvent(
+                id="ME",
+                sender=EVELYN,
+                type=EventTypes.Member,
+                state_key=EVELYN,
+                content={"membership": Membership.JOIN},
+            ),
+        ]
+
+        edges = [
+            ["END", "JR", "START"],
+            ["END", "ME", "START"],
+        ]
+
+        expected_state_ids = ["JR"]
+
+        self.do_check(events, edges, expected_state_ids)
+
+    def test_offtopic_pl(self):
+        events = [
+            FakeEvent(
+                id="PA",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    }
+                },
+            ),
+            FakeEvent(
+                id="PB",
+                sender=BOB,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                        CHARLIE: 50,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="PC",
+                sender=CHARLIE,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                        CHARLIE: 0,
+                    },
+                },
+            ),
+        ]
+
+        edges = [
+            ["END", "PC", "PB", "PA", "START"],
+            ["END", "PA"],
+        ]
+
+        expected_state_ids = ["PC"]
+
+        self.do_check(events, edges, expected_state_ids)
+
+    def test_topic_basic(self):
+        events = [
+            FakeEvent(
+                id="T1",
+                sender=ALICE,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="PA1",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="T2",
+                sender=ALICE,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="PA2",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 0,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="PB",
+                sender=BOB,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="T3",
+                sender=BOB,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+        ]
+
+        edges = [
+            ["END", "PA2", "T2", "PA1", "T1", "START"],
+            ["END", "T3", "PB", "PA1"],
+        ]
+
+        expected_state_ids = ["PA2", "T2"]
+
+        self.do_check(events, edges, expected_state_ids)
+
+    def test_topic_reset(self):
+        events = [
+            FakeEvent(
+                id="T1",
+                sender=ALICE,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="PA",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="T2",
+                sender=BOB,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="MB",
+                sender=ALICE,
+                type=EventTypes.Member,
+                state_key=BOB,
+                content={"membership": Membership.BAN},
+            ),
+        ]
+
+        edges = [
+            ["END", "MB", "T2", "PA", "T1", "START"],
+            ["END", "T1"],
+        ]
+
+        expected_state_ids = ["T1", "MB", "PA"]
+
+        self.do_check(events, edges, expected_state_ids)
+
+    def test_topic(self):
+        events = [
+            FakeEvent(
+                id="T1",
+                sender=ALICE,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="PA1",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="T2",
+                sender=ALICE,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="PA2",
+                sender=ALICE,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 0,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="PB",
+                sender=BOB,
+                type=EventTypes.PowerLevels,
+                state_key='',
+                content={
+                    "users": {
+                        ALICE: 100,
+                        BOB: 50,
+                    },
+                },
+            ),
+            FakeEvent(
+                id="T3",
+                sender=BOB,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+            FakeEvent(
+                id="MZ1",
+                sender=ZARA,
+                type=EventTypes.Message,
+                state_key=None,
+                content={},
+            ),
+            FakeEvent(
+                id="T4",
+                sender=ALICE,
+                type=EventTypes.Topic,
+                state_key="",
+                content={},
+            ),
+        ]
+
+        edges = [
+            ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
+            ["END", "MZ1", "T3", "PB", "PA1"],
+        ]
+
+        expected_state_ids = ["T4", "PA2"]
+
+        self.do_check(events, edges, expected_state_ids)
+
+    def do_check(self, events, edges, expected_state_ids):
+        """Take a list of events and edges and calculate the state of the
+        graph at END, and asserts it matches `expected_state_ids`
+
+        Args:
+            events (list[FakeEvent])
+            edges (list[list[str]]): A list of chains of event edges, e.g.
+                `[[A, B, C]]` are edges A->B and B->C.
+            expected_state_ids (list[str]): The expected state at END, (excluding
+                the keys that haven't changed since START).
+        """
+        # We want to sort the events into topological order for processing.
+        graph = {}
+
+        # node_id -> FakeEvent
+        fake_event_map = {}
+
+        for ev in itertools.chain(INITIAL_EVENTS, events):
+            graph[ev.node_id] = set()
+            fake_event_map[ev.node_id] = ev
+
+        for a, b in pairwise(INITIAL_EDGES):
+            graph[a].add(b)
+
+        for edge_list in edges:
+            for a, b in pairwise(edge_list):
+                graph[a].add(b)
+
+        # event_id -> FrozenEvent
+        event_map = {}
+        # node_id -> state
+        state_at_event = {}
+
+        # We copy the map as the sort consumes the graph
+        graph_copy = {k: set(v) for k, v in graph.items()}
+
+        for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
+            fake_event = fake_event_map[node_id]
+            event_id = fake_event.event_id
+
+            prev_events = list(graph[node_id])
+
+            if len(prev_events) == 0:
+                state_before = {}
+            elif len(prev_events) == 1:
+                state_before = dict(state_at_event[prev_events[0]])
+            else:
+                state_d = resolve_events_with_store(
+                    [state_at_event[n] for n in prev_events],
+                    event_map=event_map,
+                    state_res_store=TestStateResolutionStore(event_map),
+                )
+
+                self.assertTrue(state_d.called)
+                state_before = state_d.result
+
+            state_after = dict(state_before)
+            if fake_event.state_key is not None:
+                state_after[(fake_event.type, fake_event.state_key)] = event_id
+
+            auth_types = set(auth_types_for_event(fake_event))
+
+            auth_events = []
+            for key in auth_types:
+                if key in state_before:
+                    auth_events.append(state_before[key])
+
+            event = fake_event.to_event(auth_events, prev_events)
+
+            state_at_event[node_id] = state_after
+            event_map[event_id] = event
+
+        expected_state = {}
+        for node_id in expected_state_ids:
+            # expected_state_ids are node IDs rather than event IDs,
+            # so we have to convert
+            event_id = EventID(node_id, "example.com").to_string()
+            event = event_map[event_id]
+
+            key = (event.type, event.state_key)
+
+            expected_state[key] = event_id
+
+        start_state = state_at_event["START"]
+        end_state = {
+            key: value
+            for key, value in state_at_event["END"].items()
+            if key in expected_state or start_state.get(key) != value
+        }
+
+        self.assertEqual(expected_state, end_state)
+
+
+class LexicographicalTestCase(unittest.TestCase):
+    def test_simple(self):
+        graph = {
+            "l": {"o"},
+            "m": {"n", "o"},
+            "n": {"o"},
+            "o": set(),
+            "p": {"o"},
+        }
+
+        res = list(lexicographical_topological_sort(graph, key=lambda x: x))
+
+        self.assertEqual(["o", "l", "n", "m", "p"], res)
+
+
+def pairwise(iterable):
+    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
+    a, b = itertools.tee(iterable)
+    next(b, None)
+    return zip(a, b)
+
+
+@attr.s
+class TestStateResolutionStore(object):
+    event_map = attr.ib()
+
+    def get_events(self, event_ids, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+        """
+
+        return {
+            eid: self.event_map[eid]
+            for eid in event_ids
+            if eid in self.event_map
+        }
+
+    def get_auth_chain(self, event_ids):
+        """Gets the full auth chain for a set of events (including rejected
+        events).
+
+        Includes the given event IDs in the result.
+
+        Note that:
+            1. All events must be state events.
+            2. For v1 rooms this may not have the full auth chain in the
+               presence of rejected events
+
+        Args:
+            event_ids (list): The event IDs of the events to fetch the auth
+                chain for. Must be state events.
+
+        Returns:
+            Deferred[list[str]]: List of event IDs of the auth chain.
+        """
+
+        # Simple DFS for auth chain
+        result = set()
+        stack = list(event_ids)
+        while stack:
+            event_id = stack.pop()
+            if event_id in result:
+                continue
+
+            result.add(event_id)
+
+            event = self.event_map[event_id]
+            for aid, _ in event.auth_events:
+                stack.append(aid)
+
+        return list(result)