diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f9568638a1..b090751bf1 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -228,6 +228,25 @@ class EventContext(object):
else:
self._prev_state_ids = self._current_state_ids
+ @defer.inlineCallbacks
+ def update_state(self, state_group, prev_state_ids, current_state_ids,
+ delta_ids):
+ """Replace the state in the context
+ """
+
+ # We need to make sure we wait for any ongoing fetching of state
+ # to complete so that the updated state doesn't get clobbered
+ if self._fetching_state_deferred:
+ yield make_deferred_yieldable(self._fetching_state_deferred)
+
+ self.state_group = state_group
+ self._prev_state_ids = prev_state_ids
+ self._current_state_ids = current_state_ids
+ self.delta_ids = delta_ids
+
+ # We need to ensure that that we've marked as having fetched the state
+ self._fetching_state_deferred = defer.succeed(None)
+
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 98dd4a7fd1..14654d59f1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1975,21 +1975,35 @@ class FederationHandler(BaseHandler):
k: a.event_id for k, a in iteritems(auth_events)
if k != event_key
}
- context.current_state_ids = dict(context.current_state_ids)
- context.current_state_ids.update(state_updates)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = dict(current_state_ids)
+
+ current_state_ids.update(state_updates)
+
if context.delta_ids is not None:
- context.delta_ids = dict(context.delta_ids)
- context.delta_ids.update(state_updates)
- context.prev_state_ids = dict(context.prev_state_ids)
- context.prev_state_ids.update({
+ delta_ids = dict(context.delta_ids)
+ delta_ids.update(state_updates)
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = dict(prev_state_ids)
+
+ prev_state_ids.update({
k: a.event_id for k, a in iteritems(auth_events)
})
- context.state_group = yield self.store.store_state_group(
+
+ state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
- delta_ids=context.delta_ids,
- current_state_ids=context.current_state_ids,
+ delta_ids=delta_ids,
+ current_state_ids=current_state_ids,
+ )
+
+ yield context.update_state(
+ state_group=state_group,
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ delta_ids=delta_ids,
)
@defer.inlineCallbacks
|