summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py4
-rw-r--r--synapse/state.py18
-rw-r--r--synapse/storage/state.py28
3 files changed, 30 insertions, 20 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8a1038c44a..f7cb3c1bb2 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1585,10 +1585,12 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
+                context.current_state_ids = dict(context.current_state_ids)
                 context.current_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
                     if k != event_key
                 })
+                context.prev_state_ids = dict(context.prev_state_ids)
                 context.prev_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
                 })
@@ -1670,10 +1672,12 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
+                context.current_state_ids = dict(context.current_state_ids)
                 context.current_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
                     if k != event_key
                 })
+                context.prev_state_ids = dict(context.prev_state_ids)
                 context.prev_state_ids.update({
                     k: a.event_id for k, a in auth_events.items()
                 })
diff --git a/synapse/state.py b/synapse/state.py
index 5ce23add5d..b4eca0e5d5 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -26,6 +26,7 @@ from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
 
 from collections import namedtuple
+from frozendict import frozendict
 
 import logging
 import hashlib
@@ -58,11 +59,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
-        self.state = state
+        self.state = frozendict(state)
         self.state_group = state_group
 
         self.prev_group = prev_group
-        self.delta_ids = delta_ids
+        self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
         # The `state_id` is a unique ID we generate that can be used as ID for
         # this collection of state. Usually this would be the same as the
@@ -238,13 +239,7 @@ class StateHandler(object):
         context.prev_state_ids = curr_state
         if event.is_state():
             context.state_group = self.store.get_next_state_group()
-        else:
-            if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
-                entry.state_id = entry.state_group
-            context.state_group = entry.state_group
 
-        if event.is_state():
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
                 replaces = context.prev_state_ids[key]
@@ -256,10 +251,15 @@ class StateHandler(object):
             context.prev_group = entry.prev_group
             context.delta_ids = entry.delta_ids
             if context.delta_ids is not None:
+                context.delta_ids = dict(context.delta_ids)
                 context.delta_ids[key] = event.event_id
         else:
-            context.current_state_ids = context.prev_state_ids
+            if entry.state_group is None:
+                entry.state_group = self.store.get_next_state_group()
+                entry.state_id = entry.state_group
 
+            context.state_group = entry.state_group
+            context.current_state_ids = context.prev_state_ids
             context.prev_group = entry.prev_group
             context.delta_ids = entry.delta_ids
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index fdbdade536..7eb342674c 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -817,16 +817,24 @@ class StateStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def _background_index_state(self, progress, batch_size):
-        def reindex_txn(txn):
+        def reindex_txn(conn):
+            conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
-                txn.execute(
-                    "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
-                    " ON state_groups_state(state_group, type, state_key)"
-                )
-                txn.execute(
-                    "DROP INDEX IF EXISTS state_groups_state_id"
-                )
+                # postgres insists on autocommit for the index
+                conn.set_session(autocommit=True)
+                try:
+                    txn = conn.cursor()
+                    txn.execute(
+                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+                        " ON state_groups_state(state_group, type, state_key)"
+                    )
+                    txn.execute(
+                        "DROP INDEX IF EXISTS state_groups_state_id"
+                    )
+                finally:
+                    conn.set_session(autocommit=False)
             else:
+                txn = conn.cursor()
                 txn.execute(
                     "CREATE INDEX state_groups_state_type_idx"
                     " ON state_groups_state(state_group, type, state_key)"
@@ -835,9 +843,7 @@ class StateStore(SQLBaseStore):
                     "DROP INDEX IF EXISTS state_groups_state_id"
                 )
 
-        yield self.runInteraction(
-            self.STATE_GROUP_INDEX_UPDATE_NAME, reindex_txn
-        )
+        yield self.runWithConnection(reindex_txn)
 
         yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)