| diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..1f9abf9d3d 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -341,7 +341,7 @@ class StateHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events(
+                    new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
                         state_map_factory=lambda ev_ids: self.store.get_events(
                             ev_ids, get_prev_content=False, check_redacted=False,
@@ -404,7 +404,7 @@ class StateHandler(object):
         }
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events(state_set_ids, state_map)
+            new_state = resolve_events_with_state_map(state_set_ids, state_map)
 
         new_state = {
             key: state_map[ev_id] for key, ev_id in new_state.items()
@@ -420,19 +420,17 @@ def _ordered_events(events):
     return sorted(events, key=key_func)
 
 
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
     """
     Args:
         state_sets(list): List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
-        state_map_factory(dict|callable): If callable, then will be called
-            with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event. Otherwise, should be
-            a dict from event_id to event of all events in state_sets.
+        state_map(dict): a dict from event_id to event, for all events in
+            state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent] is a map from
-        (type, state_key) to event.
+        dict[(str, str), synapse.events.FrozenEvent]:
+            a map from (type, state_key) to event.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -441,13 +439,6 @@ def resolve_events(state_sets, state_map_factory):
         state_sets,
     )
 
-    if callable(state_map_factory):
-        return _resolve_with_state_fac(
-            unconflicted_state, conflicted_state, state_map_factory
-        )
-
-    state_map = state_map_factory
-
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -491,8 +482,26 @@ def _seperate(state_sets):
 
 
 @defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
-                            state_map_factory):
+def resolve_events_with_factory(state_sets, state_map_factory):
+    """
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+        state_map_factory(func): will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event.
+
+    Returns
+        Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
+            a map from (type, state_key) to event.
+    """
+    if len(state_sets) == 1:
+        defer.returnValue(state_sets[0])
+
+    unconflicted_state, conflicted_state = _seperate(
+        state_sets,
+    )
+
     needed_events = set(
         event_id
         for event_ids in conflicted_state.itervalues()
 |