summary refs log tree commit diff
path: root/synapse/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state.py')
-rw-r--r--synapse/state.py41
1 files changed, 33 insertions, 8 deletions
diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..668bbe1f16 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -208,7 +208,12 @@ class StateHandler(object):
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = self.store.get_next_state_group()
+            context.state_group = yield self.store.store_state_group(
+                event.event_id, event.room_id,
+                context.prev_group,
+                context.delta_ids,
+                context.current_state_ids,
+            )
             defer.returnValue(context)
 
         if old_state:
@@ -216,7 +221,6 @@ class StateHandler(object):
             context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -230,6 +234,14 @@ class StateHandler(object):
                 context.current_state_ids = context.prev_state_ids
 
             context.prev_state_events = []
+
+            context.state_group = yield self.store.store_state_group(
+                event.event_id, event.room_id,
+                context.prev_group,
+                context.delta_ids,
+                context.current_state_ids,
+            )
+
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
@@ -242,8 +254,6 @@ class StateHandler(object):
         context = EventContext()
         context.prev_state_ids = curr_state
         if event.is_state():
-            context.state_group = self.store.get_next_state_group()
-
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
@@ -261,16 +271,31 @@ class StateHandler(object):
                 context.prev_group = entry.prev_group
                 context.delta_ids = dict(entry.delta_ids)
                 context.delta_ids[key] = event.event_id
-        else:
-            if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
-                entry.state_id = entry.state_group
 
+            context.state_group = yield self.store.store_state_group(
+                event.event_id, event.room_id,
+                context.prev_group,
+                context.delta_ids,
+                context.current_state_ids,
+            )
+        else:
             context.state_group = entry.state_group
             context.current_state_ids = context.prev_state_ids
             context.prev_group = entry.prev_group
             context.delta_ids = entry.delta_ids
 
+            if entry.state_group is None:
+                entry.state_group = yield self.store.store_state_group(
+                    event.event_id, event.room_id,
+                    context.prev_group,
+                    context.delta_ids,
+                    context.current_state_ids,
+                )
+
+                entry.state_id = entry.state_group
+
+            context.state_group = entry.state_group
+
         context.prev_state_events = []
         defer.returnValue(context)