diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5fe4a0e56c..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -22,7 +22,6 @@ import logging
import simplejson as json
from twisted.internet import defer
-
from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred
from synapse.util.frozenutils import frozendict_json_encoder
@@ -425,7 +424,9 @@ class EventsStore(EventsWorkerStore):
)
current_state = yield self._get_new_state_after_events(
room_id,
- ev_ctx_rm, new_latest_event_ids,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
)
if current_state is not None:
current_state_for_room[room_id] = current_state
@@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+ new_latest_event_ids):
"""Calculate the current state dict after adding some new events to
a room
@@ -524,6 +526,9 @@ class EventsStore(EventsWorkerStore):
events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
new_latest_event_ids (iterable[str]):
the new forward extremities for the room.
@@ -534,64 +539,89 @@ class EventsStore(EventsWorkerStore):
"""
if not new_latest_event_ids:
- defer.returnValue({})
+ return
# map from state_group to ((type, key) -> event_id) state map
- state_groups = {}
- missing_event_ids = []
- was_updated = False
+ state_groups_map = {}
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
+ # First search in the list of new events we're adding.
for ev, ctx in events_context:
if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
-
- if ctx.state_group is None:
- # I don't think this can happen, but let's double-check
- raise Exception(
- "Context for new extremity event %s has no state "
- "group" % (event_id, ),
- )
-
- # If we've already seen the state group don't bother adding
- # it to the state sets again
- if ctx.state_group not in state_groups:
- state_groups[ctx.state_group] = ctx.current_state_ids
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+ event_id_to_state_group[event_id] = ctx.state_group
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
-
- if not was_updated:
- return
+ missing_event_ids.add(event_id)
if missing_event_ids:
- # Now pull out the state for any missing events from DB
+ # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
+
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
- groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+    # If the old and new groups are the same then we don't need to do
+    # anything.
+ if old_state_groups == new_state_groups:
+ return
- if groups:
- group_to_state = yield self._get_state_for_groups(groups)
- state_groups.update(group_to_state)
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
- if len(state_groups) == 1:
+ if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue(state_groups.values()[0])
+ defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
+
+ state_groups = {
+ sg: state_groups_map[sg] for sg in new_state_groups
+ }
+
events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
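
For reference, below is a minimal standalone sketch (not part of the patch) of the logic this change introduces: map each old and new forward extremity to its state group, short-circuit when the two sets of groups are identical, and only fall back to resolution when they diverge. The names (get_new_state_after_events_sketch, persisted_event_groups, db_event_to_group) and the last-writer-wins merge at the end are illustrative stand-ins, not Synapse APIs; the real code works with EventContext.state_group, _get_state_group_for_events, _get_state_for_groups and the state resolution handler.

def get_new_state_after_events_sketch(persisted_event_groups, db_event_to_group,
                                      group_to_state, old_latest_event_ids,
                                      new_latest_event_ids):
    """Return the new current-state map, or None if nothing changed."""
    if not new_latest_event_ids:
        return None

    # Map every extremity event to its state group: events being persisted in
    # this batch carry their group with them; everything else (including all
    # of the old extremities) comes from the "database" mapping.
    event_id_to_group = dict(db_event_to_group)
    event_id_to_group.update(persisted_event_groups)

    old_groups = {event_id_to_group[e] for e in old_latest_event_ids}
    new_groups = {event_id_to_group[e] for e in new_latest_event_ids}

    # Same state groups before and after: the current state cannot have
    # changed, so report "no change".
    if old_groups == new_groups:
        return None

    # A single new group means its state map is the current state.
    if len(new_groups) == 1:
        return group_to_state[next(iter(new_groups))]

    # Divergent groups: the patch hands these to the state resolution
    # handler; a naive "last writer wins" merge stands in for that here.
    merged = {}
    for group in sorted(new_groups):
        merged.update(group_to_state[group])
    return merged


if __name__ == "__main__":
    group_to_state = {
        1: {("m.room.name", ""): "$name1"},
        2: {("m.room.name", ""): "$name2"},
    }
    new_state = get_new_state_after_events_sketch(
        persisted_event_groups={"$ev2": 2},
        db_event_to_group={"$ev1": 1},
        group_to_state=group_to_state,
        old_latest_event_ids={"$ev1"},
        new_latest_event_ids={"$ev2"},
    )
    print(new_state)  # {('m.room.name', ''): '$name2'}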