| diff --git a/synapse/state.py b/synapse/state.py
index 4c8247e7c2..cc93bbcb6b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+        # dict[(str, str), str] map  from (type, state_key) to event_id
         self.state = frozendict(state)
+
+        # the ID of a state group if one and only one is involved.
+        # otherwise, None otherwise?
         self.state_group = state_group
 
         self.prev_group = prev_group
@@ -81,31 +85,19 @@ class _StateCacheEntry(object):
 
 
 class StateHandler(object):
-    """ Responsible for doing state conflict resolution.
+    """Fetches bits of state from the stores, and does state resolution
+    where necessary
     """
 
     def __init__(self, hs):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.hs = hs
-
-        # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = None
-        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+        self._state_resolution_handler = hs.get_state_resolution_handler()
 
     def start_caching(self):
-        logger.debug("start_caching")
-
-        self._state_cache = ExpiringCache(
-            cache_name="state_cache",
-            clock=self.clock,
-            max_len=SIZE_OF_CACHE,
-            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
-            iterable=True,
-            reset_expiry_on_get=True,
-        )
-
-        self._state_cache.start()
+        # TODO: remove this shim
+        self._state_resolution_handler.start_caching()
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +119,7 @@ class StateHandler(object):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         if event_type:
@@ -164,7 +156,7 @@ class StateHandler(object):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         defer.returnValue(state)
@@ -174,7 +166,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_user_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
         defer.returnValue(joined_users)
 
@@ -183,7 +175,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
         defer.returnValue(joined_hosts)
 
@@ -191,8 +183,15 @@ class StateHandler(object):
     def compute_event_context(self, event, old_state=None):
         """Build an EventContext structure for the event.
 
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
+
         Args:
             event (synapse.events.EventBase):
+            old_state (dict|None): The state at the event if it can't be
+                calculated from existing events. This is normally only specified
+                when receiving an event from federation where we don't have the
+                prev events for, e.g. when backfilling.
         Returns:
             synapse.events.snapshot.EventContext:
         """
@@ -216,15 +215,22 @@ class StateHandler(object):
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = self.store.get_next_state_group()
+
+            # We don't store state for outliers, so we don't generate a state
+            # froup for it.
+            context.state_group = None
+
             defer.returnValue(context)
 
         if old_state:
+            # We already have the state, so we don't need to calculate it.
+            # Let's just correctly fill out the context and create a
+            # new state group for it.
+
             context = EventContext()
             context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -237,11 +243,19 @@ class StateHandler(object):
             else:
                 context.current_state_ids = context.prev_state_ids
 
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=None,
+                delta_ids=None,
+                current_state_ids=context.current_state_ids,
+            )
+
             context.prev_state_events = []
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
-        entry = yield self.resolve_state_groups(
+        entry = yield self.resolve_state_groups_for_events(
             event.room_id, [e for e, _ in event.prev_events],
         )
 
@@ -250,7 +264,8 @@ class StateHandler(object):
         context = EventContext()
         context.prev_state_ids = curr_state
         if event.is_state():
-            context.state_group = self.store.get_next_state_group()
+            # If this is a state event then we need to create a new state
+            # group for the state after this event.
 
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
@@ -261,38 +276,57 @@ class StateHandler(object):
             context.current_state_ids[key] = event.event_id
 
             if entry.state_group:
+                # If the state at the event has a state group assigned then
+                # we can use that as the prev group
                 context.prev_group = entry.state_group
                 context.delta_ids = {
                     key: event.event_id
                 }
             elif entry.prev_group:
+                # If the state at the event only has a prev group, then we can
+                # use that as a prev group too.
                 context.prev_group = entry.prev_group
                 context.delta_ids = dict(entry.delta_ids)
                 context.delta_ids[key] = event.event_id
+
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=context.prev_group,
+                delta_ids=context.delta_ids,
+                current_state_ids=context.current_state_ids,
+            )
         else:
+            context.current_state_ids = context.prev_state_ids
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
             if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
+                entry.state_group = yield self.store.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=entry.prev_group,
+                    delta_ids=entry.delta_ids,
+                    current_state_ids=context.current_state_ids,
+                )
                 entry.state_id = entry.state_group
 
             context.state_group = entry.state_group
-            context.current_state_ids = context.prev_state_ids
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
 
         context.prev_state_events = []
         defer.returnValue(context)
 
     @defer.inlineCallbacks
-    @log_function
-    def resolve_state_groups(self, room_id, event_ids):
+    def resolve_state_groups_for_events(self, room_id, event_ids):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
+        Args:
+            room_id (str):
+            event_ids (list[str]):
+
         Returns:
-            a Deferred tuple of (`state_group`, `state`, `prev_state`).
-            `state_group` is the name of a state group if one and only one is
-            involved. `state` is a map from (type, state_key) to event, and
-            `prev_state` is a list of event ids.
+            Deferred[_StateCacheEntry]: resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -303,13 +337,7 @@ class StateHandler(object):
             room_id, event_ids
         )
 
-        logger.debug(
-            "resolve_state_groups state_groups %s",
-            state_groups_ids.keys()
-        )
-
-        group_names = frozenset(state_groups_ids.keys())
-        if len(group_names) == 1:
+        if len(state_groups_ids) == 1:
             name, state_list = state_groups_ids.items().pop()
 
             prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -321,6 +349,92 @@ class StateHandler(object):
                 delta_ids=delta_ids,
             ))
 
+        result = yield self._state_resolution_handler.resolve_state_groups(
+            room_id, state_groups_ids, self._state_map_factory,
+        )
+        defer.returnValue(result)
+
+    def _state_map_factory(self, ev_ids):
+        return self.store.get_events(
+            ev_ids, get_prev_content=False, check_redacted=False,
+        )
+
+    def resolve_events(self, state_sets, event):
+        logger.info(
+            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+        )
+        state_set_ids = [{
+            (ev.type, ev.state_key): ev.event_id
+            for ev in st
+        } for st in state_sets]
+
+        state_map = {
+            ev.event_id: ev
+            for st in state_sets
+            for ev in st
+        }
+
+        with Measure(self.clock, "state._resolve_events"):
+            new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+        new_state = {
+            key: state_map[ev_id] for key, ev_id in new_state.items()
+        }
+
+        return new_state
+
+
+class StateResolutionHandler(object):
+    """Responsible for doing state conflict resolution.
+
+    Note that the storage layer depends on this handler, so all functions must
+    be storage-independent.
+    """
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+
+        # dict of set of event_ids -> _StateCacheEntry.
+        self._state_cache = None
+        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+    def start_caching(self):
+        logger.debug("start_caching")
+
+        self._state_cache = ExpiringCache(
+            cache_name="state_cache",
+            clock=self.clock,
+            max_len=SIZE_OF_CACHE,
+            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+            iterable=True,
+            reset_expiry_on_get=True,
+        )
+
+        self._state_cache.start()
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+        """Resolves conflicts between a set of state groups
+
+        Always generates a new state group (unless we hit the cache), so should
+        not be called for a single state group
+
+        Args:
+            room_id (str): room we are resolving for (used for logging)
+            state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+        Returns:
+            Deferred[_StateCacheEntry]: resolved state
+        """
+        logger.debug(
+            "resolve_state_groups state_groups %s",
+            state_groups_ids.keys()
+        )
+
+        group_names = frozenset(state_groups_ids.keys())
+
         with (yield self.resolve_linearizer.queue(group_names)):
             if self._state_cache is not None:
                 cache = self._state_cache.get(group_names, None)
@@ -351,15 +465,17 @@ class StateHandler(object):
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
-                        state_map_factory=lambda ev_ids: self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False,
-                        ),
+                        state_map_factory=state_map_factory,
                     )
             else:
                 new_state = {
                     key: e_ids.pop() for key, e_ids in state.items()
                 }
 
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
             state_group = None
             new_state_event_ids = frozenset(new_state.values())
             for sg, events in state_groups_ids.items():
@@ -396,30 +512,6 @@ class StateHandler(object):
 
             defer.returnValue(cache)
 
-    def resolve_events(self, state_sets, event):
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [{
-            (ev.type, ev.state_key): ev.event_id
-            for ev in st
-        } for st in state_sets]
-
-        state_map = {
-            ev.event_id: ev
-            for st in state_sets
-            for ev in st
-        }
-
-        with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events_with_state_map(state_set_ids, state_map)
-
-        new_state = {
-            key: state_map[ev_id] for key, ev_id in new_state.items()
-        }
-
-        return new_state
-
 
 def _ordered_events(events):
     def key_func(e):
@@ -437,8 +529,8 @@ def resolve_events_with_state_map(state_sets, state_map):
             state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent]:
-            a map from (type, state_key) to event.
+        dict[(str, str), str]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -460,6 +552,21 @@ def _seperate(state_sets):
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
+
+    Args:
+        state_sets(list[dict[(str, str), str]]):
+            List of dicts of (type, state_key) -> event_id, which are the
+            different state groups to resolve.
+
+    Returns:
+        (dict[(str, str), str], dict[(str, str), set[str]]):
+            A tuple of (unconflicted_state, conflicted_state), where:
+
+            unconflicted_state is a dict mapping (type, state_key)->event_id
+            for unconflicted state keys.
+
+            conflicted_state is a dict mapping (type, state_key) to a set of
+            event ids for conflicted state keys.
     """
     unconflicted_state = dict(state_sets[0])
     conflicted_state = {}
@@ -500,8 +607,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
             a Deferred of dict of event_id to event.
 
     Returns
-        Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
-            a map from (type, state_key) to event.
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         defer.returnValue(state_sets[0])
 |