diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e3e67d8e0d..6017c2a6e8 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
+from util.id_generators import IdGenerator, StreamIdGenerator
+
from twisted.internet import defer
from collections import namedtuple, OrderedDict
@@ -145,11 +147,12 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name"]
+ __slots__ = ["txn", "name", "database_engine"]
- def __init__(self, txn, name):
+ def __init__(self, txn, name, database_engine):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
+ object.__setattr__(self, "database_engine", database_engine)
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -161,26 +164,32 @@ class LoggingTransaction(object):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
- try:
- if args and args[0]:
- values = args[0]
+ sql = self.database_engine.convert_param_style(sql)
+
+ if args and args[0]:
+ args = list(args)
+ args[0] = [
+ self.database_engine.encode_parameter(a) for a in args[0]
+ ]
+ try:
sql_logger.debug(
- "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
+ "[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
self.name,
- *values
+ *args[0]
)
- except:
- # Don't let logging failures stop SQL from working
- pass
+ except:
+ # Don't let logging failures stop SQL from working
+ pass
start = time.time() * 1000
+
try:
return self.txn.execute(
sql, *args, **kwargs
)
- except:
- logger.exception("[SQL FAIL] {%s}", self.name)
- raise
+ except Exception as e:
+ logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ raise
finally:
msecs = (time.time() * 1000) - start
sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
@@ -245,6 +254,14 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
+ self.database_engine = hs.database_engine
+
+ self._stream_id_gen = StreamIdGenerator()
+ self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+ self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+ self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._pushers_id_gen = IdGenerator("pushers", "id", self)
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -281,8 +298,11 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
- def inner_func(txn, *args, **kwargs):
+ def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
+ if self.database_engine.is_connection_closed(conn):
+ conn.reconnect()
+
current_context.copy_to(context)
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -296,9 +316,48 @@ class SQLBaseStore(object):
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
- return func(LoggingTransaction(txn, name), *args, **kwargs)
- except:
- logger.exception("[TXN FAIL] {%s}", name)
+ i = 0
+ N = 5
+ while True:
+ try:
+ txn = conn.cursor()
+ return func(
+ LoggingTransaction(txn, name, self.database_engine),
+ *args, **kwargs
+ )
+ except self.database_engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warn(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name, e, i, N
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ except self.database_engine.module.DatabaseError as e:
+ if self.database_engine.is_deadlock(e):
+ logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ raise
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
@@ -311,7 +370,7 @@ class SQLBaseStore(object):
sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext():
- result = yield self._db_pool.runInteraction(
+ result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
defer.returnValue(result)
@@ -342,11 +401,11 @@ class SQLBaseStore(object):
The result of decoder(results)
"""
def interaction(txn):
- cursor = txn.execute(query, args)
+ txn.execute(query, args)
if decoder:
- return decoder(cursor)
+ return decoder(txn)
else:
- return cursor.fetchall()
+ return txn.fetchall()
return self.runInteraction(desc, interaction)
@@ -356,27 +415,29 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
+ @defer.inlineCallbacks
+ def _simple_insert(self, table, values, or_ignore=False,
desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
table : string giving the table name
values : dict of new column names and values for them
- or_replace : bool; if True performs an INSERT OR REPLACE
"""
- return self.runInteraction(
- desc,
- self._simple_insert_txn, table, values, or_replace=or_replace,
- or_ignore=or_ignore,
- )
+ try:
+ yield self.runInteraction(
+ desc,
+ self._simple_insert_txn, table, values,
+ )
+ except self.database_engine.module.IntegrityError:
+ # We have to do or_ignore flag at this layer, since we can't reuse
+ # a cursor after we receive an error from the db.
+ if not or_ignore:
+ raise
@log_function
- def _simple_insert_txn(self, txn, table, values, or_replace=False,
- or_ignore=False):
- sql = "%s INTO %s (%s) VALUES(%s)" % (
- ("INSERT OR REPLACE" if or_replace else
- "INSERT OR IGNORE" if or_ignore else "INSERT"),
+ def _simple_insert_txn(self, txn, table, values):
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in values),
", ".join("?" for k in values)
@@ -388,22 +449,26 @@ class SQLBaseStore(object):
)
txn.execute(sql, values.values())
- return txn.lastrowid
- def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
+ def _simple_upsert(self, table, keyvalues, values,
+ insertion_values={}, desc="_simple_upsert"):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
+ insertion_values (dict): key/values to use when inserting
Returns: A deferred
"""
return self.runInteraction(
desc,
- self._simple_upsert_txn, table, keyvalues, values
+ self._simple_upsert_txn, table, keyvalues, values, insertion_values,
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values):
+ def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}):
+ # We need to lock the table :(
+ self.database_engine.lock_table(txn, table)
+
# Try to update
sql = "UPDATE %s SET %s WHERE %s" % (
table,
@@ -422,6 +487,7 @@ class SQLBaseStore(object):
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
+ allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
@@ -489,8 +555,7 @@ class SQLBaseStore(object):
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = (
- "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
- "ORDER BY rowid asc"
+ "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
"retcol": retcol,
"table": table,
@@ -548,14 +613,14 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
txn.execute(sql, keyvalues.values())
else:
- sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s" % (
", ".join(retcols),
table
)
@@ -607,10 +672,10 @@ class SQLBaseStore(object):
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
allow_none=False):
- select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
)
txn.execute(select_sql, keyvalues.values())
@@ -648,6 +713,11 @@ class SQLBaseStore(object):
updatevalues=updatevalues,
)
+ # if txn.rowcount == 0:
+ # raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
return ret
return self.runInteraction(desc, func)
@@ -860,6 +930,12 @@ class SQLBaseStore(object):
result = txn.fetchone()
return result[0] if result else None
+ def get_next_stream_id(self):
+ with self._next_stream_id_lock:
+ i = self._next_stream_id
+ self._next_stream_id += 1
+ return i
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
@@ -883,7 +959,7 @@ class Table(object):
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
- _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+ _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
|