diff options
author | Erik Johnston <erik@matrix.org> | 2015-05-14 16:54:35 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2015-05-14 16:54:35 +0100 |
commit | 1d566edb81e1dffea026d4e603a12cee664a8eda (patch) | |
tree | cd5487b49933a00ea255dd8df15ca94e293922f1 | |
parent | Call from right thread (diff) | |
download | synapse-1d566edb81e1dffea026d4e603a12cee664a8eda.tar.xz |
Remove race condition
Diffstat (limited to '')
-rw-r--r-- | synapse/storage/_base.py | 168 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 2 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite3.py | 2 | ||||
-rw-r--r-- | synapse/storage/events.py | 81 |
4 files changed, 157 insertions, 96 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 97bf42469a..ceff99c16d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -26,6 +26,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator from twisted.internet import defer from collections import namedtuple, OrderedDict + +import contextlib import functools import sys import time @@ -299,7 +301,7 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) - self._event_fetch_lock = threading.Lock() + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -342,6 +344,84 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) + @contextlib.contextmanager + def _new_transaction(self, conn, desc, after_callbacks): + start = time.time() * 1000 + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + + name = "%s-%x" % (desc, txn_id, ) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + try: + txn = conn.cursor() + txn = LoggingTransaction( + txn, name, self.database_engine, after_callbacks + ) + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warn( + "[TXN OPERROR] {%s} %s %d/%d", + name, e, i, N + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warn( + "[TXN EROLL] {%s} %s", + name, e1, + ) + continue + raise + + try: + yield txn + conn.commit() + return + except: + try: + conn.rollback() + except: + pass + raise + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = time.time() * 1000 + duration = end - start + + transaction_logger.debug("[TXN END] {%s} %f", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, start, end) + sql_txn_timer.inc_by(duration, desc) + @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): """Wraps the .runInteraction() method on the underlying db_pool.""" @@ -353,83 +433,49 @@ class SQLBaseStore(object): def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: + sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() current_context.copy_to(context) - start = time.time() * 1000 - txn_id = self._TXN_ID + with self._new_transaction(conn, desc, after_callbacks) as txn: + return func(txn, *args, **kwargs) - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1) + result = yield preserve_context_over_fn( + self._db_pool.runWithConnection, + inner_func, *args, **kwargs + ) - name = "%s-%x" % (desc, txn_id, ) + for after_callback, after_args in after_callbacks: + after_callback(*after_args) + defer.returnValue(result) + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runInteraction() method on the underlying db_pool.""" + current_context = LoggingContext.current_context() + + start_time = time.time() * 1000 + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection") as context: sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) - transaction_logger.debug("[TXN START] {%s}", name) - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks - ) - return func(txn, *args, **kwargs) - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warn( - "[TXN OPERROR] {%s} %s %d/%d", - name, e, i, N - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warn( - "[TXN EROLL] {%s} %s", - name, e1, - ) - continue - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warn( - "[TXN EROLL] {%s} %s", - name, e1, - ) - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = time.time() * 1000 - duration = end - start - transaction_logger.debug("[TXN END] {%s} %f", name, duration) + if self.database_engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + current_context.copy_to(context) - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, start, end) - sql_txn_timer.inc_by(duration, desc) + return func(conn, *args, **kwargs) result = yield preserve_context_over_fn( self._db_pool.runWithConnection, inner_func, *args, **kwargs ) - for after_callback, after_args in after_callbacks: - after_callback(*after_args) defer.returnValue(result) def cursor_to_dict(self, cursor): diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a323028546..4a855ffd56 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup class PostgresEngine(object): + single_threaded = False + def __init__(self, database_module): self.module = database_module self.module.extensions.register_type(self.module.extensions.UNICODE) diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index ff13d8006a..d18e2808d1 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database class Sqlite3Engine(object): + single_threaded = True + def __init__(self, database_module): self.module = database_module diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 59af21a2ca..b4abd83260 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -504,23 +504,26 @@ class EventsStore(SQLBaseStore): if not events: defer.returnValue({}) - def do_fetch(txn): + def do_fetch(conn): event_list = [] while True: try: with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: + i = 0 + while not self._event_fetch_list: self._event_fetch_ongoing -= 1 return + event_list = self._event_fetch_list + self._event_fetch_list = [] + event_id_lists = zip(*event_list)[0] event_ids = [ item for sublist in event_id_lists for item in sublist ] - rows = self._fetch_event_rows(txn, event_ids) + + with self._new_transaction(conn, "do_fetch", []) as txn: + rows = self._fetch_event_rows(txn, event_ids) row_dict = { r["event_id"]: r @@ -528,22 +531,44 @@ class EventsStore(SQLBaseStore): } for ids, d in event_list: - reactor.callFromThread( - d.callback, - [ - row_dict[i] for i in ids - if i in row_dict - ] - ) + def fire(): + if not d.called: + d.callback( + [ + row_dict[i] + for i in ids + if i in row_dict + ] + ) + reactor.callFromThread(fire) except Exception as e: + logger.exception("do_fetch") for _, d in event_list: - try: + if not d.called: reactor.callFromThread(d.errback, e) - except: - pass - def cb(rows): - return defer.gatherResults([ + with self._event_fetch_lock: + self._event_fetch_ongoing -= 1 + return + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify_all() + + # if self._event_fetch_ongoing < 5: + self._event_fetch_ongoing += 1 + self.runWithConnection( + do_fetch + ) + + rows = yield events_d + + res = yield defer.gatherResults( + [ self._get_event_from_row( None, row["internal_metadata"], row["json"], row["redacts"], @@ -552,23 +577,9 @@ class EventsStore(SQLBaseStore): rejected_reason=row["rejects"], ) for row in rows - ]) - - d = defer.Deferred() - d.addCallback(cb) - with self._event_fetch_lock: - self._event_fetch_list.append( - (events, d) - ) - - if self._event_fetch_ongoing < 3: - self._event_fetch_ongoing += 1 - self.runInteraction( - "do_fetch", - do_fetch - ) - - res = yield d + ], + consumeErrors=True + ) defer.returnValue({ e.event_id: e |