summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state.py9
-rw-r--r--synapse/storage/events.py10
-rw-r--r--synapse/storage/state.py24
3 files changed, 24 insertions, 19 deletions
diff --git a/synapse/state.py b/synapse/state.py
index daec983dc9..147416fd81 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -160,14 +160,14 @@ class StateHandler(object):
             else:
                 context.current_state_ids = {}
             context.prev_state_events = []
-            context.state_group = None
+            context.state_group = self.store.get_next_state_group()
             defer.returnValue(context)
 
         if old_state:
             context.current_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = None
+            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -193,7 +193,10 @@ class StateHandler(object):
         group, curr_state = ret
 
         context.current_state_ids = curr_state
-        context.state_group = group if not event.is_state() else None
+        if event.is_state() or group is None:
+            context.state_group = self.store.get_next_state_group()
+        else:
+            context.state_group = group
 
         if event.is_state():
             key = (event.type, event.state_key)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index bc1bc97e19..1a7d4c5199 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -276,13 +276,6 @@ class EventsStore(SQLBaseStore):
                 events_and_contexts, stream_orderings
             ):
                 event.internal_metadata.stream_ordering = stream
-                # Assign a state group_id in case a new id is needed for
-                # this context. In theory we only need to assign this
-                # for contexts that have current_state and aren't outliers
-                # but that make the code more complicated. Assigning an ID
-                # per event only causes the state_group_ids to grow as fast
-                # as the stream_ordering so in practise shouldn't be a problem.
-                context.new_state_group_id = self._state_groups_id_gen.get_next()
 
             chunks = [
                 events_and_contexts[x:x + 100]
@@ -309,7 +302,6 @@ class EventsStore(SQLBaseStore):
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
                 event.internal_metadata.stream_ordering = stream_ordering
-                context.new_state_group_id = self._state_groups_id_gen.get_next()
                 yield self.runInteraction(
                     "persist_event",
                     self._persist_event_txn,
@@ -523,7 +515,7 @@ class EventsStore(SQLBaseStore):
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
                 stream_order = event.internal_metadata.stream_ordering
-                state_group_id = context.state_group or context.new_state_group_id
+                state_group_id = context.state_group
                 self._simple_insert_txn(
                     txn,
                     table="ex_outlier_stream",
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0442353287..56bfdc0b55 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
             for group, event_id_map in group_to_ids.items()
         })
 
+    def _have_persisted_state_group_txn(self, txn, state_group):
+        txn.execute(
+            "SELECT count(*) FROM state_groups_state WHERE state_group = ?",
+            (state_group,)
+        )
+        row = txn.fetchone()
+        return row and row[0]
+
     def _store_mult_state_groups_txn(self, txn, events_and_contexts):
         state_groups = {}
         for event, context in events_and_contexts:
@@ -92,8 +100,10 @@ class StateStore(SQLBaseStore):
             if context.current_state_ids is None:
                 continue
 
-            if context.state_group is not None:
-                state_groups[event.event_id] = context.state_group
+            state_groups[event.event_id] = context.state_group
+
+            if self._have_persisted_state_group_txn(txn, context.state_group):
+                logger.info("Already persisted state_group: %r", context.state_group)
                 continue
 
             state_event_ids = dict(context.current_state_ids)
@@ -101,13 +111,11 @@ class StateStore(SQLBaseStore):
             if event.is_state():
                 state_event_ids[(event.type, event.state_key)] = event.event_id
 
-            state_group = context.new_state_group_id
-
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": state_group,
+                    "id": context.state_group,
                     "room_id": event.room_id,
                     "event_id": event.event_id,
                 },
@@ -118,7 +126,7 @@ class StateStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": state_group,
+                        "state_group": context.state_group,
                         "room_id": event.room_id,
                         "type": key[0],
                         "state_key": key[1],
@@ -127,7 +135,6 @@ class StateStore(SQLBaseStore):
                     for key, state_id in state_event_ids.items()
                 ],
             )
-            state_groups[event.event_id] = state_group
 
         self._simple_insert_many_txn(
             txn,
@@ -526,3 +533,6 @@ class StateStore(SQLBaseStore):
         return self.runInteraction(
             "get_all_new_state_groups", get_all_new_state_groups_txn
         )
+
+    def get_next_state_group(self):
+        return self._state_groups_id_gen.get_next()