| diff --git a/synapse/state.py b/synapse/state.py
index cc93bbcb6b..932f602508 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -350,7 +350,7 @@ class StateHandler(object):
             ))
 
         result = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, state_groups_ids, self._state_map_factory,
+            room_id, state_groups_ids, None, self._state_map_factory,
         )
         defer.returnValue(result)
 
@@ -413,7 +413,9 @@ class StateResolutionHandler(object):
 
     @defer.inlineCallbacks
     @log_function
-    def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+    def resolve_state_groups(
+        self, room_id, state_groups_ids, event_map, state_map_factory,
+    ):
         """Resolves conflicts between a set of state groups
 
         Always generates a new state group (unless we hit the cache), so should
@@ -425,6 +427,14 @@ class StateResolutionHandler(object):
                  map from state group id to the state in that state group
                 (where 'state' is a map from state key to event id)
 
+            event_map(dict[str,FrozenEvent]|None):
+                a dict from event_id to event, for any events that we happen to
+                have in flight (eg, those currently being persisted). This will be
+                used as a starting point fof finding the state we need; any missing
+                events will be requested via state_map_factory.
+
+                If None, all events will be fetched via state_map_factory.
+
         Returns:
             Deferred[_StateCacheEntry]: resolved state
         """
@@ -465,6 +475,7 @@ class StateResolutionHandler(object):
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
+                        event_map=event_map,
                         state_map_factory=state_map_factory,
                     )
             else:
@@ -597,11 +608,20 @@ def _seperate(state_sets):
 
 
 @defer.inlineCallbacks
-def resolve_events_with_factory(state_sets, state_map_factory):
+def resolve_events_with_factory(state_sets, event_map, state_map_factory):
     """
     Args:
         state_sets(list): List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
+
+        event_map(dict[str,FrozenEvent]|None):
+            a dict from event_id to event, for any events that we happen to
+            have in flight (eg, those currently being persisted). This will be
+            used as a starting point fof finding the state we need; any missing
+            events will be requested via state_map_factory.
+
+            If None, all events will be fetched via state_map_factory.
+
         state_map_factory(func): will be called
             with a list of event_ids that are needed, and should return with
             a Deferred of dict of event_id to event.
@@ -622,12 +642,16 @@ def resolve_events_with_factory(state_sets, state_map_factory):
         for event_ids in conflicted_state.itervalues()
         for event_id in event_ids
     )
+    if event_map is not None:
+        needed_events -= set(event_map.iterkeys())
 
     logger.info("Asking for %d conflicted events", len(needed_events))
 
     # dict[str, FrozenEvent]: a map from state event id to event. Only includes
-    # the state events which are in conflict.
+    # the state events which are in conflict (and those in event_map)
     state_map = yield state_map_factory(needed_events)
+    if event_map is not None:
+        state_map.update(event_map)
 
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
@@ -639,6 +663,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
 
     new_needed_events = set(auth_events.itervalues())
     new_needed_events -= needed_events
+    if event_map is not None:
+        new_needed_events -= set(event_map.iterkeys())
 
     logger.info("Asking for %d auth events", len(new_needed_events))
 |