summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/events.py104
1 files changed, 67 insertions, 37 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5fe4a0e56c..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -22,7 +22,6 @@ import logging
 import simplejson as json
 from twisted.internet import defer
 
-
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.util.async import ObservableDeferred
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -425,7 +424,9 @@ class EventsStore(EventsWorkerStore):
                             )
                             current_state = yield self._get_new_state_after_events(
                                 room_id,
-                                ev_ctx_rm, new_latest_event_ids,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
                             )
                             if current_state is not None:
                                 current_state_for_room[room_id] = current_state
@@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+                                    new_latest_event_ids):
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -524,6 +526,9 @@ class EventsStore(EventsWorkerStore):
             events_context (list[(EventBase, EventContext)]):
                 events and contexts which are being added to the room
 
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
             new_latest_event_ids (iterable[str]):
                 the new forward extremities for the room.
 
@@ -534,64 +539,89 @@ class EventsStore(EventsWorkerStore):
         """
 
         if not new_latest_event_ids:
-            defer.returnValue({})
+            return
 
         # map from state_group to ((type, key) -> event_id) state map
-        state_groups = {}
-        missing_event_ids = []
-        was_updated = False
+        state_groups_map = {}
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # I don't think this can happen, but let's double-check
+                raise Exception(
+                    "Context for new extremity event %s has no state "
+                    "group" % (ev.event_id, ),
+                )
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
         for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding,
-            # and then use the current state from that
+            # First search in the list of new events we're adding.
             for ev, ctx in events_context:
                 if event_id == ev.event_id:
-                    if ctx.current_state_ids is None:
-                        raise Exception("Unknown current state")
-
-                    if ctx.state_group is None:
-                        # I don't think this can happen, but let's double-check
-                        raise Exception(
-                            "Context for new extremity event %s has no state "
-                            "group" % (event_id, ),
-                        )
-
-                    # If we've already seen the state group don't bother adding
-                    # it to the state sets again
-                    if ctx.state_group not in state_groups:
-                        state_groups[ctx.state_group] = ctx.current_state_ids
-                        if ctx.delta_ids or hasattr(ev, "state_key"):
-                            was_updated = True
+                    event_id_to_state_group[event_id] = ctx.state_group
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
                 # the state from the database
-                was_updated = True
-                missing_event_ids.append(event_id)
-
-        if not was_updated:
-            return
+                missing_event_ids.add(event_id)
 
         if missing_event_ids:
-            # Now pull out the state for any missing events from DB
+            # Now pull out the state groups for any missing events from DB
             event_to_groups = yield self._get_state_group_for_events(
                 missing_event_ids,
             )
+            event_id_to_state_group.update(event_to_groups)
+
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
+
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
 
-            groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+        # If they old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
+            return
 
-            if groups:
-                group_to_state = yield self._get_state_for_groups(groups)
-                state_groups.update(group_to_state)
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
 
-        if len(state_groups) == 1:
+        if len(new_state_groups) == 1:
             # If there is only one state group, then we know what the current
             # state is.
-            defer.returnValue(state_groups.values()[0])
+            defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
 
         def get_events(ev_ids):
             return self.get_events(
                 ev_ids, get_prev_content=False, check_redacted=False,
             )
+
+        state_groups = {
+            sg: state_groups_map[sg] for sg in new_state_groups
+        }
+
         events_map = {ev.event_id: ev for ev, _ in events_context}
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(