summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-05-14 16:54:35 +0100
committerErik Johnston <erik@matrix.org>2015-05-14 16:54:35 +0100
commit1d566edb81e1dffea026d4e603a12cee664a8eda (patch)
treecd5487b49933a00ea255dd8df15ca94e293922f1 /synapse/storage/_base.py
parentCall from right thread (diff)
downloadsynapse-1d566edb81e1dffea026d4e603a12cee664a8eda.tar.xz
Remove race condition
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py168
1 files changed, 107 insertions, 61 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 97bf42469a..ceff99c16d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,6 +26,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
+
+import contextlib
 import functools
 import sys
 import time
@@ -299,7 +301,7 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
-        self._event_fetch_lock = threading.Lock()
+        self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
@@ -342,6 +344,84 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
+    @contextlib.contextmanager
+    def _new_transaction(self, conn, desc, after_callbacks):
+        start = time.time() * 1000
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so lets stop it from
+        # growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+        name = "%s-%x" % (desc, txn_id, )
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                try:
+                    txn = conn.cursor()
+                    txn = LoggingTransaction(
+                        txn, name, self.database_engine, after_callbacks
+                    )
+                except self.database_engine.module.OperationalError as e:
+                    # This can happen if the database disappears mid
+                    # transaction.
+                    logger.warn(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name, e, i, N
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.database_engine.module.Error as e1:
+                            logger.warn(
+                                "[TXN EROLL] {%s} %s",
+                                name, e1,
+                            )
+                        continue
+                    raise
+                except self.database_engine.module.DatabaseError as e:
+                    if self.database_engine.is_deadlock(e):
+                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.database_engine.module.Error as e1:
+                                logger.warn(
+                                    "[TXN EROLL] {%s} %s",
+                                    name, e1,
+                                )
+                            continue
+                    raise
+
+                try:
+                    yield txn
+                    conn.commit()
+                    return
+                except:
+                    try:
+                        conn.rollback()
+                    except:
+                        pass
+                    raise
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = time.time() * 1000
+            duration = end - start
+
+            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, start, end)
+            sql_txn_timer.inc_by(duration, desc)
+
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
@@ -353,83 +433,49 @@ class SQLBaseStore(object):
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
                 current_context.copy_to(context)
-                start = time.time() * 1000
-                txn_id = self._TXN_ID
+                with self._new_transaction(conn, desc, after_callbacks) as txn:
+                    return func(txn, *args, **kwargs)
 
-                # We don't really need these to be unique, so lets stop it from
-                # growing really large.
-                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-                name = "%s-%x" % (desc, txn_id, )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
+        defer.returnValue(result)
 
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runInteraction() method on the underlying db_pool."""
+        current_context = LoggingContext.current_context()
+
+        start_time = time.time() * 1000
+
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection") as context:
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-                transaction_logger.debug("[TXN START] {%s}", name)
-                try:
-                    i = 0
-                    N = 5
-                    while True:
-                        try:
-                            txn = conn.cursor()
-                            txn = LoggingTransaction(
-                                txn, name, self.database_engine, after_callbacks
-                            )
-                            return func(txn, *args, **kwargs)
-                        except self.database_engine.module.OperationalError as e:
-                            # This can happen if the database disappears mid
-                            # transaction.
-                            logger.warn(
-                                "[TXN OPERROR] {%s} %s %d/%d",
-                                name, e, i, N
-                            )
-                            if i < N:
-                                i += 1
-                                try:
-                                    conn.rollback()
-                                except self.database_engine.module.Error as e1:
-                                    logger.warn(
-                                        "[TXN EROLL] {%s} %s",
-                                        name, e1,
-                                    )
-                                continue
-                        except self.database_engine.module.DatabaseError as e:
-                            if self.database_engine.is_deadlock(e):
-                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                                if i < N:
-                                    i += 1
-                                    try:
-                                        conn.rollback()
-                                    except self.database_engine.module.Error as e1:
-                                        logger.warn(
-                                            "[TXN EROLL] {%s} %s",
-                                            name, e1,
-                                        )
-                                    continue
-                            raise
-                except Exception as e:
-                    logger.debug("[TXN FAIL] {%s} %s", name, e)
-                    raise
-                finally:
-                    end = time.time() * 1000
-                    duration = end - start
 
-                    transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+                if self.database_engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
+
+                current_context.copy_to(context)
 
-                    self._current_txn_total_time += duration
-                    self._txn_perf_counters.update(desc, start, end)
-                    sql_txn_timer.inc_by(duration, desc)
+                return func(conn, *args, **kwargs)
 
         result = yield preserve_context_over_fn(
             self._db_pool.runWithConnection,
             inner_func, *args, **kwargs
         )
 
-        for after_callback, after_args in after_callbacks:
-            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):