summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/events.py90
-rw-r--r--synapse/storage/state.py16
3 files changed, 59 insertions, 49 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 250ba536ea..aaad38039e 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,7 +116,7 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+        self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5233430028..5f675ab09b 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
 from canonicaljson import encode_canonical_json
 from contextlib import contextmanager
 
+
 import logging
 import math
 import ujson as json
@@ -79,41 +80,57 @@ class EventsStore(SQLBaseStore):
                 len(events_and_contexts)
             )
 
+        state_group_id_manager = self._state_groups_id_gen.get_next_mult(
+            len(events_and_contexts)
+        )
         with stream_ordering_manager as stream_orderings:
-            for (event, _), stream in zip(events_and_contexts, stream_orderings):
-                event.internal_metadata.stream_ordering = stream
-
-            chunks = [
-                events_and_contexts[x:x + 100]
-                for x in xrange(0, len(events_and_contexts), 100)
-            ]
+            with state_group_id_manager as state_group_ids:
+                for (event, context), stream, state_group_id in zip(
+                    events_and_contexts, stream_orderings, state_group_ids
+                ):
+                    event.internal_metadata.stream_ordering = stream
+                    # Assign a state group_id in case a new id is needed for
+                    # this context. In theory we only need to assign this
+                    # for contexts that have current_state and aren't outliers
+                    # but that make the code more complicated. Assigning an ID
+                    # per event only causes the state_group_ids to grow as fast
+                    # as the stream_ordering so in practise shouldn't be a problem.
+                    context.new_state_group_id = state_group_id
+
+                chunks = [
+                    events_and_contexts[x:x + 100]
+                    for x in xrange(0, len(events_and_contexts), 100)
+                ]
 
-            for chunk in chunks:
-                # We can't easily parallelize these since different chunks
-                # might contain the same event. :(
-                yield self.runInteraction(
-                    "persist_events",
-                    self._persist_events_txn,
-                    events_and_contexts=chunk,
-                    backfilled=backfilled,
-                    is_new_state=is_new_state,
-                )
+                for chunk in chunks:
+                    # We can't easily parallelize these since different chunks
+                    # might contain the same event. :(
+                    yield self.runInteraction(
+                        "persist_events",
+                        self._persist_events_txn,
+                        events_and_contexts=chunk,
+                        backfilled=backfilled,
+                        is_new_state=is_new_state,
+                    )
 
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context,
                       is_new_state=True, current_state=None):
+
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
-                event.internal_metadata.stream_ordering = stream_ordering
-                yield self.runInteraction(
-                    "persist_event",
-                    self._persist_event_txn,
-                    event=event,
-                    context=context,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
+                with self._state_groups_id_gen.get_next() as state_group_id:
+                    event.internal_metadata.stream_ordering = stream_ordering
+                    context.new_state_group_id = state_group_id
+                    yield self.runInteraction(
+                        "persist_event",
+                        self._persist_event_txn,
+                        event=event,
+                        context=context,
+                        is_new_state=is_new_state,
+                        current_state=current_state,
+                    )
         except _RollbackButIsFineException:
             pass
 
@@ -178,7 +195,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_event_txn(self, txn, event, context,
-                           is_new_state=True, current_state=None):
+                           is_new_state, current_state):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
@@ -215,7 +232,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            is_new_state=True):
+                            is_new_state):
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -282,9 +299,7 @@ class EventsStore(SQLBaseStore):
 
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
-                self._store_state_groups_txn(
-                    txn, event, context,
-                )
+                self._store_mult_state_groups_txn(txn, ((event, context),))
 
                 metadata_json = encode_json(
                     event.internal_metadata.get_dict()
@@ -310,19 +325,14 @@ class EventsStore(SQLBaseStore):
 
                 self._update_extremeties(txn, [event])
 
-        events_and_contexts = filter(
-            lambda ec: ec[0] not in to_remove,
-            events_and_contexts
-        )
+        events_and_contexts = [
+            ec for ec in events_and_contexts if ec[0] not in to_remove
+        ]
 
         if not events_and_contexts:
             return
 
-        self._store_mult_state_groups_txn(txn, [
-            (event, context)
-            for event, context in events_and_contexts
-            if not event.internal_metadata.is_outlier()
-        ])
+        self._store_mult_state_groups_txn(txn, events_and_contexts)
 
         self._handle_mult_prev_events(
             txn,
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 02cefdff26..30d1060ecd 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -64,12 +64,12 @@ class StateStore(SQLBaseStore):
             for group, state_map in group_to_state.items()
         })
 
-    def _store_state_groups_txn(self, txn, event, context):
-        return self._store_mult_state_groups_txn(txn, [(event, context)])
-
     def _store_mult_state_groups_txn(self, txn, events_and_contexts):
         state_groups = {}
         for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
             if context.current_state is None:
                 continue
 
@@ -82,7 +82,8 @@ class StateStore(SQLBaseStore):
             if event.is_state():
                 state_events[(event.type, event.state_key)] = event
 
-            state_group = self._state_groups_id_gen.get_next()
+            state_group = context.new_state_group_id
+
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
@@ -114,11 +115,10 @@ class StateStore(SQLBaseStore):
             table="event_to_state_groups",
             values=[
                 {
-                    "state_group": state_groups[event.event_id],
-                    "event_id": event.event_id,
+                    "state_group": state_group_id,
+                    "event_id": event_id,
                 }
-                for event, context in events_and_contexts
-                if context.current_state is not None
+                for event_id, state_group_id in state_groups.items()
             ],
         )