diff --git a/synapse/state.py b/synapse/state.py
index 273f9911ca..6c2aaa5e7a 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -308,7 +308,7 @@ class StateHandler(object):
))
result = yield self._state_resolution_handler.resolve_state_groups(
- room_id, state_groups_ids, self._state_map_factory,
+ room_id, state_groups_ids, None, self._state_map_factory,
)
defer.returnValue(result)
@@ -371,7 +371,9 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks
@log_function
- def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+ def resolve_state_groups(
+ self, room_id, state_groups_ids, event_map, state_map_factory,
+ ):
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
@@ -383,6 +385,14 @@ class StateResolutionHandler(object):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point fof finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
@@ -423,6 +433,7 @@ class StateResolutionHandler(object):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
+ event_map=event_map,
state_map_factory=state_map_factory,
)
else:
@@ -555,11 +566,20 @@ def _seperate(state_sets):
@defer.inlineCallbacks
-def resolve_events_with_factory(state_sets, state_map_factory):
+def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point fof finding the state we need; any missing
+ events will be requested via state_map_factory.
+
+ If None, all events will be fetched via state_map_factory.
+
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
@@ -580,12 +600,16 @@ def resolve_events_with_factory(state_sets, state_map_factory):
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
+ if event_map is not None:
+ needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
- # the state events which are in conflict.
+ # the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
+ if event_map is not None:
+ state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
@@ -597,6 +621,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
+ if event_map is not None:
+ new_needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d auth events", len(new_needed_events))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2fead9eb0f..7b912ad413 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -586,6 +586,7 @@ class EventsStore(SQLBaseStore):
current_state = yield resolve_events_with_factory(
state_sets,
+ event_map={},
state_map_factory=get_events,
)
defer.returnValue(current_state)
|