summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py205
1 files changed, 118 insertions, 87 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2fd9f4045b..599db4c9f0 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -284,71 +284,37 @@ class EventsStore(SQLBaseStore):
                 new_forward_extremeties = {}
                 current_state_for_room = {}
                 if not backfilled:
-                    # Work out the new "current state" for each room.
-                    # We do this by working out what the new extremities are and then
-                    # calculating the state from that.
-                    events_by_room = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in events_by_room.items():
-                        # Work out new extremities by recursively adding and removing
-                        # the new events.
-                        latest_event_ids = yield self.get_latest_event_ids_in_room(
-                            room_id
-                        )
-                        new_latest_event_ids = yield self._calculate_new_extremeties(
-                            room_id, [ev for ev, _ in ev_ctx_rm]
-                        )
-
-                        if new_latest_event_ids == set(latest_event_ids):
-                            # No change in extremities, so no change in state
-                            continue
+                    with Measure(self._clock, "_calculate_state_and_extrem"):
+                        # Work out the new "current state" for each room.
+                        # We do this by working out what the new extremities are and then
+                        # calculating the state from that.
+                        events_by_room = {}
+                        for event, context in chunk:
+                            events_by_room.setdefault(event.room_id, []).append(
+                                (event, context)
+                            )
 
-                        new_forward_extremeties[room_id] = new_latest_event_ids
-
-                        # Now we need to work out the different state sets for
-                        # each state extremities
-                        state_sets = []
-                        missing_event_ids = []
-                        was_updated = False
-                        for event_id in new_latest_event_ids:
-                            # First search in the list of new events we're adding,
-                            # and then use the current state from that
-                            for ev, ctx in ev_ctx_rm:
-                                if event_id == ev.event_id:
-                                    if ctx.current_state_ids is None:
-                                        raise Exception("Unknown current state")
-                                    state_sets.append(ctx.current_state_ids)
-                                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                                        was_updated = True
-                                    break
-                            else:
-                                # If we couldn't find it, then we'll need to pull
-                                # the state from the database
-                                was_updated = True
-                                missing_event_ids.append(event_id)
-
-                        if missing_event_ids:
-                            # Now pull out the state for any missing events from DB
-                            event_to_groups = yield self._get_state_group_for_events(
-                                missing_event_ids,
+                        for room_id, ev_ctx_rm in events_by_room.items():
+                            # Work out new extremities by recursively adding and removing
+                            # the new events.
+                            latest_event_ids = yield self.get_latest_event_ids_in_room(
+                                room_id
+                            )
+                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                                room_id, [ev for ev, _ in ev_ctx_rm]
                             )
 
-                            groups = set(event_to_groups.values())
-                            group_to_state = yield self._get_state_for_groups(groups)
+                            if new_latest_event_ids == set(latest_event_ids):
+                                # No change in extremities, so no change in state
+                                continue
 
-                            state_sets.extend(group_to_state.values())
+                            new_forward_extremeties[room_id] = new_latest_event_ids
 
-                        if not new_latest_event_ids or was_updated:
-                            current_state_for_room[room_id] = yield resolve_events(
-                                state_sets,
-                                state_map_factory=lambda ev_ids: self.get_events(
-                                    ev_ids, get_prev_content=False, check_redacted=False,
-                                ),
+                            state = yield self._calculate_state_delta(
+                                room_id, ev_ctx_rm, new_latest_event_ids
                             )
+                            if state:
+                                current_state_for_room[room_id] = state
 
                 yield self.runInteraction(
                     "persist_events",
@@ -406,6 +372,91 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+            (type, state_key) -> event_id. `to_delete` are the entreis to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+            May return None if there are no changes to be applied.
+        """
+        # Now we need to work out the different state sets for
+        # each state extremities
+        state_sets = []
+        missing_event_ids = []
+        was_updated = False
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding,
+            # and then use the current state from that
+            for ev, ctx in events_context:
+                if event_id == ev.event_id:
+                    if ctx.current_state_ids is None:
+                        raise Exception("Unknown current state")
+                    state_sets.append(ctx.current_state_ids)
+                    if ctx.delta_ids or hasattr(ev, "state_key"):
+                        was_updated = True
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                was_updated = True
+                missing_event_ids.append(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state for any missing events from DB
+            event_to_groups = yield self._get_state_group_for_events(
+                missing_event_ids,
+            )
+
+            groups = set(event_to_groups.values())
+            group_to_state = yield self._get_state_for_groups(groups)
+
+            state_sets.extend(group_to_state.values())
+
+        if not new_latest_event_ids:
+            current_state = {}
+        elif was_updated:
+            current_state = yield resolve_events(
+                state_sets,
+                state_map_factory=lambda ev_ids: self.get_events(
+                    ev_ids, get_prev_content=False, check_redacted=False,
+                ),
+            )
+        else:
+            return
+
+        existing_state_rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+
+        existing_events = set(row["event_id"] for row in existing_state_rows)
+        new_events = set(ev_id for ev_id in current_state.itervalues())
+        changed_events = existing_events ^ new_events
+
+        if not changed_events:
+            return
+
+        to_delete = {
+            (row["type"], row["state_key"]): row["event_id"]
+            for row in existing_state_rows
+            if row["event_id"] in changed_events
+        }
+        events_to_insert = (new_events - existing_events)
+        to_insert = {
+            key: ev_id for key, ev_id in current_state.iteritems()
+            if ev_id in events_to_insert
+        }
+
+        defer.returnValue((to_delete, to_insert))
+
+    @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,
                   allow_none=False):
@@ -475,32 +526,13 @@ class EventsStore(SQLBaseStore):
         database before insertion. This is useful when retrying due to IntegrityError.
         """
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state in current_state_for_room.iteritems():
-            existing_state_rows = self._simple_select_list_txn(
-                txn,
-                table="current_state_events",
-                keyvalues={"room_id": room_id},
-                retcols=["event_id", "type", "state_key"],
-            )
-
-            # Figure out what has changed (if anything). Then we simply delete
-            # and readd the keys that have been changed.
-            # This saves us from deleting and reinserting thousands of rows for
-            # large rooms.
-            existing_events = set(row["event_id"] for row in existing_state_rows)
-            new_events = set(ev_id for ev_id in current_state.itervalues())
-            changed_events = existing_events ^ new_events
-            if changed_events:
+        for room_id, current_state_tuple in current_state_for_room.iteritems():
+                to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
-                    [(ev_id,) for ev_id in changed_events],
+                    [(ev_id,) for ev_id in to_delete.itervalues()],
                 )
 
-                events_to_insert = (new_events - existing_events)
-                to_insert = [
-                    (key, ev_id) for key, ev_id in current_state.iteritems()
-                    if ev_id in events_to_insert
-                ]
                 self._simple_insert_many_txn(
                     txn,
                     table="current_state_events",
@@ -511,7 +543,7 @@ class EventsStore(SQLBaseStore):
                             "type": key[0],
                             "state_key": key[1],
                         }
-                        for key, ev_id in to_insert
+                        for key, ev_id in to_insert.iteritems()
                     ],
                 )
 
@@ -523,13 +555,12 @@ class EventsStore(SQLBaseStore):
                 # and which we have added, then we invlidate the caches for all
                 # those users.
                 members_changed = set(
-                    row["state_key"] for row in existing_state_rows
-                    if row["event_id"] in changed_events
-                    and row["type"] == EventTypes.Member
+                    state_key for ev_type, state_key in to_delete.iterkeys()
+                    if ev_type == EventTypes.Member
                 )
                 members_changed.update(
-                    key[1] for key, event_id in to_insert
-                    if key[0] == EventTypes.Member
+                    state_key for ev_type, state_key in to_insert.iterkeys()
+                    if ev_type == EventTypes.Member
                 )
 
                 for member in members_changed: