summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2017-06-07 17:51:25 +0100
committerGitHub <noreply@github.com>2017-06-07 17:51:25 +0100
commitc62c480dc6f166f65580fad377cdbd28849978a1 (patch)
tree7a24a41a68080274770c5903cf9a4df5d989e63e /synapse/storage/_base.py
parentMerge pull request #2258 from matrix-org/erikj/user_dir (diff)
parentFix bug where state_group tables got corrupted (diff)
downloadsynapse-c62c480dc6f166f65580fad377cdbd28849978a1.tar.xz
Merge pull request #2259 from matrix-org/erikj/fix_state_woes
Fix bug where state_group tables got corrupted
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py29
1 files changed, 21 insertions, 8 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index db816346f5..51730a88bf 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -52,13 +52,17 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
+    __slots__ = [
+        "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+    ]
 
-    def __init__(self, txn, name, database_engine, after_callbacks):
+    def __init__(self, txn, name, database_engine, after_callbacks,
+                 final_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
+        object.__setattr__(self, "final_callbacks", final_callbacks)
 
     def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
@@ -67,6 +71,9 @@ class LoggingTransaction(object):
         """
         self.after_callbacks.append((callback, args, kwargs))
 
+    def call_finally(self, callback, *args, **kwargs):
+        self.final_callbacks.append((callback, args, kwargs))
+
     def __getattr__(self, name):
         return getattr(self.txn, name)
 
@@ -217,8 +224,8 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    def _new_transaction(self, conn, desc, after_callbacks, logging_context,
-                         func, *args, **kwargs):
+    def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
+                         logging_context, func, *args, **kwargs):
         start = time.time() * 1000
         txn_id = self._TXN_ID
 
@@ -237,7 +244,8 @@ class SQLBaseStore(object):
                 try:
                     txn = conn.cursor()
                     txn = LoggingTransaction(
-                        txn, name, self.database_engine, after_callbacks
+                        txn, name, self.database_engine, after_callbacks,
+                        final_callbacks,
                     )
                     r = func(txn, *args, **kwargs)
                     conn.commit()
@@ -298,6 +306,7 @@ class SQLBaseStore(object):
         start_time = time.time() * 1000
 
         after_callbacks = []
+        final_callbacks = []
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
@@ -309,7 +318,7 @@ class SQLBaseStore(object):
 
                 current_context.copy_to(context)
                 return self._new_transaction(
-                    conn, desc, after_callbacks, current_context,
+                    conn, desc, after_callbacks, final_callbacks, current_context,
                     func, *args, **kwargs
                 )
 
@@ -318,9 +327,13 @@ class SQLBaseStore(object):
                 result = yield self._db_pool.runWithConnection(
                     inner_func, *args, **kwargs
                 )
-        finally:
+
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
+        finally:
+            for after_callback, after_args, after_kwargs in final_callbacks:
+                after_callback(*after_args, **after_kwargs)
+
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -941,7 +954,7 @@ class SQLBaseStore(object):
             # __exit__ called after the transaction finishes.
             ctx = self._cache_id_gen.get_next()
             stream_id = ctx.__enter__()
-            txn.call_after(ctx.__exit__, None, None, None)
+            txn.call_finally(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             self._simple_insert_txn(