summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py28
1 files changed, 16 insertions, 12 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index b1d461fef5..ec551b0b4f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
             for group, event_id_map in group_to_ids.items()
         })
 
+    def _have_persisted_state_group_txn(self, txn, state_group):
+        txn.execute(
+            "SELECT count(*) FROM state_groups WHERE id = ?",
+            (state_group,)
+        )
+        row = txn.fetchone()
+        return row and row[0]
+
     def _store_mult_state_groups_txn(self, txn, events_and_contexts):
         state_groups = {}
         for event, context in events_and_contexts:
@@ -92,22 +100,19 @@ class StateStore(SQLBaseStore):
             if context.current_state_ids is None:
                 continue
 
-            if context.state_group is not None:
-                state_groups[event.event_id] = context.state_group
+            state_groups[event.event_id] = context.state_group
+
+            if self._have_persisted_state_group_txn(txn, context.state_group):
+                logger.info("Already persisted state_group: %r", context.state_group)
                 continue
 
             state_event_ids = dict(context.current_state_ids)
 
-            if event.is_state():
-                state_event_ids[(event.type, event.state_key)] = event.event_id
-
-            state_group = context.new_state_group_id
-
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": state_group,
+                    "id": context.state_group,
                     "room_id": event.room_id,
                     "event_id": event.event_id,
                 },
@@ -118,7 +123,7 @@ class StateStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": state_group,
+                        "state_group": context.state_group,
                         "room_id": event.room_id,
                         "type": key[0],
                         "state_key": key[1],
@@ -127,7 +132,6 @@ class StateStore(SQLBaseStore):
                     for key, state_id in state_event_ids.items()
                 ],
             )
-            state_groups[event.event_id] = state_group
 
         self._simple_insert_many_txn(
             txn,
@@ -527,5 +531,5 @@ class StateStore(SQLBaseStore):
             "get_all_new_state_groups", get_all_new_state_groups_txn
         )
 
-    def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_current_token()
+    def get_next_state_group(self):
+        return self._state_groups_id_gen.get_next()