summary refs log tree commit diff
path: root/synapse/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r--synapse/storage/events.py90
1 files changed, 50 insertions, 40 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5233430028..5f675ab09b 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
 from canonicaljson import encode_canonical_json
 from contextlib import contextmanager
 
+
 import logging
 import math
 import ujson as json
@@ -79,41 +80,57 @@ class EventsStore(SQLBaseStore):
                 len(events_and_contexts)
             )
 
+        state_group_id_manager = self._state_groups_id_gen.get_next_mult(
+            len(events_and_contexts)
+        )
         with stream_ordering_manager as stream_orderings:
-            for (event, _), stream in zip(events_and_contexts, stream_orderings):
-                event.internal_metadata.stream_ordering = stream
-
-            chunks = [
-                events_and_contexts[x:x + 100]
-                for x in xrange(0, len(events_and_contexts), 100)
-            ]
+            with state_group_id_manager as state_group_ids:
+                for (event, context), stream, state_group_id in zip(
+                    events_and_contexts, stream_orderings, state_group_ids
+                ):
+                    event.internal_metadata.stream_ordering = stream
+                    # Assign a state group_id in case a new id is needed for
+                    # this context. In theory we only need to assign this
+                    # for contexts that have current_state and aren't outliers
+                    # but that make the code more complicated. Assigning an ID
+                    # per event only causes the state_group_ids to grow as fast
+                    # as the stream_ordering so in practise shouldn't be a problem.
+                    context.new_state_group_id = state_group_id
+
+                chunks = [
+                    events_and_contexts[x:x + 100]
+                    for x in xrange(0, len(events_and_contexts), 100)
+                ]
 
-            for chunk in chunks:
-                # We can't easily parallelize these since different chunks
-                # might contain the same event. :(
-                yield self.runInteraction(
-                    "persist_events",
-                    self._persist_events_txn,
-                    events_and_contexts=chunk,
-                    backfilled=backfilled,
-                    is_new_state=is_new_state,
-                )
+                for chunk in chunks:
+                    # We can't easily parallelize these since different chunks
+                    # might contain the same event. :(
+                    yield self.runInteraction(
+                        "persist_events",
+                        self._persist_events_txn,
+                        events_and_contexts=chunk,
+                        backfilled=backfilled,
+                        is_new_state=is_new_state,
+                    )
 
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context,
                       is_new_state=True, current_state=None):
+
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
-                event.internal_metadata.stream_ordering = stream_ordering
-                yield self.runInteraction(
-                    "persist_event",
-                    self._persist_event_txn,
-                    event=event,
-                    context=context,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
+                with self._state_groups_id_gen.get_next() as state_group_id:
+                    event.internal_metadata.stream_ordering = stream_ordering
+                    context.new_state_group_id = state_group_id
+                    yield self.runInteraction(
+                        "persist_event",
+                        self._persist_event_txn,
+                        event=event,
+                        context=context,
+                        is_new_state=is_new_state,
+                        current_state=current_state,
+                    )
         except _RollbackButIsFineException:
             pass
 
@@ -178,7 +195,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_event_txn(self, txn, event, context,
-                           is_new_state=True, current_state=None):
+                           is_new_state, current_state):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
@@ -215,7 +232,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            is_new_state=True):
+                            is_new_state):
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -282,9 +299,7 @@ class EventsStore(SQLBaseStore):
 
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
-                self._store_state_groups_txn(
-                    txn, event, context,
-                )
+                self._store_mult_state_groups_txn(txn, ((event, context),))
 
                 metadata_json = encode_json(
                     event.internal_metadata.get_dict()
@@ -310,19 +325,14 @@ class EventsStore(SQLBaseStore):
 
                 self._update_extremeties(txn, [event])
 
-        events_and_contexts = filter(
-            lambda ec: ec[0] not in to_remove,
-            events_and_contexts
-        )
+        events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0] not in to_remove
+        ]
 
         if not events_and_contexts:
             return
 
-        self._store_mult_state_groups_txn(txn, [
-            (event, context)
-            for event, context in events_and_contexts
-            if not event.internal_metadata.is_outlier()
-        ])
+        self._store_mult_state_groups_txn(txn, events_and_contexts)
 
         self._handle_mult_prev_events(
             txn,