diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 135 |
1 files changed, 91 insertions, 44 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e3e67d8e0d..f5952d1fc0 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.lrucache import LruCache import synapse.metrics +from util.id_generators import IdGenerator, StreamIdGenerator + from twisted.internet import defer from collections import namedtuple, OrderedDict @@ -145,11 +147,12 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" - __slots__ = ["txn", "name"] + __slots__ = ["txn", "name", "database_engine"] - def __init__(self, txn, name): + def __init__(self, txn, name, database_engine): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) def __getattr__(self, name): return getattr(self.txn, name) @@ -161,26 +164,32 @@ class LoggingTransaction(object): # TODO(paul): Maybe use 'info' and 'debug' for values? sql_logger.debug("[SQL] {%s} %s", self.name, sql) - try: - if args and args[0]: - values = args[0] + sql = self.database_engine.convert_param_style(sql) + + if args and args[0]: + args = list(args) + args[0] = [ + self.database_engine.encode_parameter(a) for a in args[0] + ] + try: sql_logger.debug( - "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)), + "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])), self.name, - *values + *args[0] ) - except: - # Don't let logging failures stop SQL from working - pass + except: + # Don't let logging failures stop SQL from working + pass start = time.time() * 1000 + try: return self.txn.execute( sql, *args, **kwargs ) - except: - logger.exception("[SQL FAIL] {%s}", self.name) - raise + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise finally: msecs = (time.time() * 1000) - start sql_logger.debug("[SQL time] {%s} %f", self.name, msecs) @@ -245,6 +254,14 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self.database_engine = hs.database_engine + + self._stream_id_gen = StreamIdGenerator() + self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) + self._state_groups_id_gen = IdGenerator("state_groups", "id", self) + self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) + self._pushers_id_gen = IdGenerator("pushers", "id", self) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -281,7 +298,7 @@ class SQLBaseStore(object): start_time = time.time() * 1000 - def inner_func(txn, *args, **kwargs): + def inner_func(conn, *args, **kwargs): with LoggingContext("runInteraction") as context: current_context.copy_to(context) start = time.time() * 1000 @@ -296,9 +313,25 @@ class SQLBaseStore(object): sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) transaction_logger.debug("[TXN START] {%s}", name) try: - return func(LoggingTransaction(txn, name), *args, **kwargs) - except: - logger.exception("[TXN FAIL] {%s}", name) + i = 0 + N = 5 + while True: + try: + txn = conn.cursor() + return func( + LoggingTransaction(txn, name, self.database_engine), + *args, **kwargs + ) + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + conn.rollback() + continue + raise + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) raise finally: end = time.time() * 1000 @@ -311,7 +344,7 @@ class SQLBaseStore(object): sql_txn_timer.inc_by(duration, desc) with PreserveLoggingContext(): - result = yield self._db_pool.runInteraction( + result = yield self._db_pool.runWithConnection( inner_func, *args, **kwargs ) defer.returnValue(result) @@ -342,11 +375,11 @@ class SQLBaseStore(object): The result of decoder(results) """ def interaction(txn): - cursor = txn.execute(query, args) + txn.execute(query, args) if decoder: - return decoder(cursor) + return decoder(txn) else: - return cursor.fetchall() + return txn.fetchall() return self.runInteraction(desc, interaction) @@ -356,27 +389,29 @@ class SQLBaseStore(object): # "Simple" SQL API methods that operate on a single table with no JOINs, # no complex WHERE clauses, just a dict of values for columns. - def _simple_insert(self, table, values, or_replace=False, or_ignore=False, + @defer.inlineCallbacks + def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): """Executes an INSERT query on the named table. Args: table : string giving the table name values : dict of new column names and values for them - or_replace : bool; if True performs an INSERT OR REPLACE """ - return self.runInteraction( - desc, - self._simple_insert_txn, table, values, or_replace=or_replace, - or_ignore=or_ignore, - ) + try: + yield self.runInteraction( + desc, + self._simple_insert_txn, table, values, + ) + except self.database_engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise @log_function - def _simple_insert_txn(self, txn, table, values, or_replace=False, - or_ignore=False): - sql = "%s INTO %s (%s) VALUES(%s)" % ( - ("INSERT OR REPLACE" if or_replace else - "INSERT OR IGNORE" if or_ignore else "INSERT"), + def _simple_insert_txn(self, txn, table, values): + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in values), ", ".join("?" for k in values) @@ -388,22 +423,23 @@ class SQLBaseStore(object): ) txn.execute(sql, values.values()) - return txn.lastrowid - def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"): + def _simple_upsert(self, table, keyvalues, values, + insertion_values={}, desc="_simple_upsert"): """ Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values + insertion_values (dict): key/values to use when inserting Returns: A deferred """ return self.runInteraction( desc, - self._simple_upsert_txn, table, keyvalues, values + self._simple_upsert_txn, table, keyvalues, values, insertion_values, ) - def _simple_upsert_txn(self, txn, table, keyvalues, values): + def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}): # Try to update sql = "UPDATE %s SET %s WHERE %s" % ( table, @@ -422,6 +458,7 @@ class SQLBaseStore(object): allvalues = {} allvalues.update(keyvalues) allvalues.update(values) + allvalues.update(insertion_values) sql = "INSERT INTO %s (%s) VALUES (%s)" % ( table, @@ -489,8 +526,7 @@ class SQLBaseStore(object): def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): sql = ( - "SELECT %(retcol)s FROM %(table)s WHERE %(where)s " - "ORDER BY rowid asc" + "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" ) % { "retcol": retcol, "table": table, @@ -548,14 +584,14 @@ class SQLBaseStore(object): retcols : list of strings giving the names of the columns to return """ if keyvalues: - sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( + sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, " AND ".join("%s = ?" % (k, ) for k in keyvalues) ) txn.execute(sql, keyvalues.values()) else: - sql = "SELECT %s FROM %s ORDER BY rowid asc" % ( + sql = "SELECT %s FROM %s" % ( ", ".join(retcols), table ) @@ -607,10 +643,10 @@ class SQLBaseStore(object): def _simple_select_one_txn(self, txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( + select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues) ) txn.execute(select_sql, keyvalues.values()) @@ -648,6 +684,11 @@ class SQLBaseStore(object): updatevalues=updatevalues, ) + # if txn.rowcount == 0: + # raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched") + return ret return self.runInteraction(desc, func) @@ -860,6 +901,12 @@ class SQLBaseStore(object): result = txn.fetchone() return result[0] if result else None + def get_next_stream_id(self): + with self._next_stream_id_lock: + i = self._next_stream_id + self._next_stream_id += 1 + return i + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying @@ -883,7 +930,7 @@ class Table(object): _select_where_clause = "SELECT %s FROM %s WHERE %s" _select_clause = "SELECT %s FROM %s" - _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" + _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)" @classmethod def select_statement(cls, where_clause=None): |