summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-05-14 16:54:35 +0100
committerErik Johnston <erik@matrix.org>2015-05-14 16:54:35 +0100
commit1d566edb81e1dffea026d4e603a12cee664a8eda (patch)
treecd5487b49933a00ea255dd8df15ca94e293922f1 /synapse
parentCall from right thread (diff)
downloadsynapse-1d566edb81e1dffea026d4e603a12cee664a8eda.tar.xz
Remove race condition
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/_base.py168
-rw-r--r--synapse/storage/engines/postgres.py2
-rw-r--r--synapse/storage/engines/sqlite3.py2
-rw-r--r--synapse/storage/events.py81
4 files changed, 157 insertions, 96 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 97bf42469a..ceff99c16d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -26,6 +26,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
+
+import contextlib
 import functools
 import sys
 import time
@@ -299,7 +301,7 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
-        self._event_fetch_lock = threading.Lock()
+        self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
@@ -342,6 +344,84 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
+    @contextlib.contextmanager
+    def _new_transaction(self, conn, desc, after_callbacks):
+        start = time.time() * 1000
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so lets stop it from
+        # growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+        name = "%s-%x" % (desc, txn_id, )
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                try:
+                    txn = conn.cursor()
+                    txn = LoggingTransaction(
+                        txn, name, self.database_engine, after_callbacks
+                    )
+                except self.database_engine.module.OperationalError as e:
+                    # This can happen if the database disappears mid
+                    # transaction.
+                    logger.warn(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name, e, i, N
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.database_engine.module.Error as e1:
+                            logger.warn(
+                                "[TXN EROLL] {%s} %s",
+                                name, e1,
+                            )
+                        continue
+                    raise
+                except self.database_engine.module.DatabaseError as e:
+                    if self.database_engine.is_deadlock(e):
+                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.database_engine.module.Error as e1:
+                                logger.warn(
+                                    "[TXN EROLL] {%s} %s",
+                                    name, e1,
+                                )
+                            continue
+                    raise
+
+                try:
+                    yield txn
+                    conn.commit()
+                    return
+                except:
+                    try:
+                        conn.rollback()
+                    except:
+                        pass
+                    raise
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = time.time() * 1000
+            duration = end - start
+
+            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, start, end)
+            sql_txn_timer.inc_by(duration, desc)
+
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
@@ -353,83 +433,49 @@ class SQLBaseStore(object):
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
                 current_context.copy_to(context)
-                start = time.time() * 1000
-                txn_id = self._TXN_ID
+                with self._new_transaction(conn, desc, after_callbacks) as txn:
+                    return func(txn, *args, **kwargs)
 
-                # We don't really need these to be unique, so lets stop it from
-                # growing really large.
-                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-                name = "%s-%x" % (desc, txn_id, )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
+        defer.returnValue(result)
 
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runInteraction() method on the underlying db_pool."""
+        current_context = LoggingContext.current_context()
+
+        start_time = time.time() * 1000
+
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection") as context:
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-                transaction_logger.debug("[TXN START] {%s}", name)
-                try:
-                    i = 0
-                    N = 5
-                    while True:
-                        try:
-                            txn = conn.cursor()
-                            txn = LoggingTransaction(
-                                txn, name, self.database_engine, after_callbacks
-                            )
-                            return func(txn, *args, **kwargs)
-                        except self.database_engine.module.OperationalError as e:
-                            # This can happen if the database disappears mid
-                            # transaction.
-                            logger.warn(
-                                "[TXN OPERROR] {%s} %s %d/%d",
-                                name, e, i, N
-                            )
-                            if i < N:
-                                i += 1
-                                try:
-                                    conn.rollback()
-                                except self.database_engine.module.Error as e1:
-                                    logger.warn(
-                                        "[TXN EROLL] {%s} %s",
-                                        name, e1,
-                                    )
-                                continue
-                        except self.database_engine.module.DatabaseError as e:
-                            if self.database_engine.is_deadlock(e):
-                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                                if i < N:
-                                    i += 1
-                                    try:
-                                        conn.rollback()
-                                    except self.database_engine.module.Error as e1:
-                                        logger.warn(
-                                            "[TXN EROLL] {%s} %s",
-                                            name, e1,
-                                        )
-                                    continue
-                            raise
-                except Exception as e:
-                    logger.debug("[TXN FAIL] {%s} %s", name, e)
-                    raise
-                finally:
-                    end = time.time() * 1000
-                    duration = end - start
 
-                    transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+                if self.database_engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
+
+                current_context.copy_to(context)
 
-                    self._current_txn_total_time += duration
-                    self._txn_perf_counters.update(desc, start, end)
-                    sql_txn_timer.inc_by(duration, desc)
+                return func(conn, *args, **kwargs)
 
         result = yield preserve_context_over_fn(
             self._db_pool.runWithConnection,
             inner_func, *args, **kwargs
         )
 
-        for after_callback, after_args in after_callbacks:
-            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a323028546..4a855ffd56 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
 
 
 class PostgresEngine(object):
+    single_threaded = False
+
     def __init__(self, database_module):
         self.module = database_module
         self.module.extensions.register_type(self.module.extensions.UNICODE)
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index ff13d8006a..d18e2808d1 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
 
 
 class Sqlite3Engine(object):
+    single_threaded = True
+
     def __init__(self, database_module):
         self.module = database_module
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 59af21a2ca..b4abd83260 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -504,23 +504,26 @@ class EventsStore(SQLBaseStore):
         if not events:
             defer.returnValue({})
 
-        def do_fetch(txn):
+        def do_fetch(conn):
             event_list = []
             while True:
                 try:
                     with self._event_fetch_lock:
-                        event_list = self._event_fetch_list
-                        self._event_fetch_list = []
-
-                        if not event_list:
+                        i = 0
+                        while not self._event_fetch_list:
                             self._event_fetch_ongoing -= 1
                             return
 
+                        event_list = self._event_fetch_list
+                        self._event_fetch_list = []
+
                     event_id_lists = zip(*event_list)[0]
                     event_ids = [
                         item for sublist in event_id_lists for item in sublist
                     ]
-                    rows = self._fetch_event_rows(txn, event_ids)
+
+                    with self._new_transaction(conn, "do_fetch", []) as txn:
+                        rows = self._fetch_event_rows(txn, event_ids)
 
                     row_dict = {
                         r["event_id"]: r
@@ -528,22 +531,44 @@ class EventsStore(SQLBaseStore):
                     }
 
                     for ids, d in event_list:
-                        reactor.callFromThread(
-                            d.callback,
-                            [
-                                row_dict[i] for i in ids
-                                if i in row_dict
-                            ]
-                        )
+                        def fire():
+                            if not d.called:
+                                d.callback(
+                                    [
+                                        row_dict[i]
+                                        for i in ids
+                                        if i in row_dict
+                                    ]
+                                )
+                        reactor.callFromThread(fire)
                 except Exception as e:
+                    logger.exception("do_fetch")
                     for _, d in event_list:
-                        try:
+                        if not d.called:
                             reactor.callFromThread(d.errback, e)
-                        except:
-                            pass
 
-        def cb(rows):
-            return defer.gatherResults([
+                    with self._event_fetch_lock:
+                        self._event_fetch_ongoing -= 1
+                        return
+
+        events_d = defer.Deferred()
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, events_d)
+            )
+
+            self._event_fetch_lock.notify_all()
+
+            # if self._event_fetch_ongoing < 5:
+            self._event_fetch_ongoing += 1
+            self.runWithConnection(
+                do_fetch
+            )
+
+        rows = yield events_d
+
+        res = yield defer.gatherResults(
+            [
                 self._get_event_from_row(
                     None,
                     row["internal_metadata"], row["json"], row["redacts"],
@@ -552,23 +577,9 @@ class EventsStore(SQLBaseStore):
                     rejected_reason=row["rejects"],
                 )
                 for row in rows
-            ])
-
-        d = defer.Deferred()
-        d.addCallback(cb)
-        with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, d)
-            )
-
-            if self._event_fetch_ongoing < 3:
-                self._event_fetch_ongoing += 1
-                self.runInteraction(
-                    "do_fetch",
-                    do_fetch
-                )
-
-        res = yield d
+            ],
+            consumeErrors=True
+        )
 
         defer.returnValue({
             e.event_id: e