diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6f54036d67..1d41d8d445 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -13,36 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import sys
+import threading
+import time
-from synapse.api.errors import StoreError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches import CACHE_SIZE_FACTOR
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.caches.descriptors import Cache
-from synapse.storage.engines import PostgresEngine
-import synapse.metrics
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+from prometheus_client import Histogram
from twisted.internet import defer
-import sys
-import time
-import threading
-
+from synapse.api.errors import StoreError
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import Cache
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
+try:
+ MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+ # python 3 does not have a maximum int value
+ MAX_TXN_ID = 2**63 - 1
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-metrics = synapse.metrics.get_metrics_for("synapse.storage")
-
-sql_scheduling_timer = metrics.register_distribution("schedule_time")
-
-sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
-sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
class LoggingTransaction(object):
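
The hunk above swaps synapse's home-grown `register_distribution` metrics for `prometheus_client.Histogram`. A minimal sketch of the new API, assuming `prometheus_client` is installed; the `demo_` metric name is a placeholder, not one of synapse's:

```python
from prometheus_client import Histogram, generate_latest

# One histogram per metric, with a label per SQL verb as in the diff.
query_timer = Histogram("demo_storage_query_time", "sec", ["verb"])

# .labels() selects the child time series, .observe() records a sample.
# Durations are now in seconds, matching Prometheus conventions.
query_timer.labels("SELECT").observe(0.004)
query_timer.labels("UPDATE").observe(0.012)

# Render the default registry in the Prometheus exposition format.
print(generate_latest().decode("ascii"))
```

Note that `observe()` expects seconds, which is why the timing code further down switches from milliseconds.
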
@@ -50,16 +52,16 @@ class LoggingTransaction(object):
passed to the constructor. Adds logging and metrics to the .execute()
method."""
__slots__ = [
- "txn", "name", "database_engine", "after_callbacks", "final_callbacks",
+ "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks,
- final_callbacks):
+ exception_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "final_callbacks", final_callbacks)
+ object.__setattr__(self, "exception_callbacks", exception_callbacks)
def call_after(self, callback, *args, **kwargs):
"""Call the given callback on the main twisted thread after the
@@ -68,8 +70,8 @@ class LoggingTransaction(object):
"""
self.after_callbacks.append((callback, args, kwargs))
- def call_finally(self, callback, *args, **kwargs):
- self.final_callbacks.append((callback, args, kwargs))
+ def call_on_exception(self, callback, *args, **kwargs):
+ self.exception_callbacks.append((callback, args, kwargs))
def __getattr__(self, name):
return getattr(self.txn, name)
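
The rename from `final_callbacks` to `exception_callbacks` is a behaviour change, not just a spelling one: the callbacks now run only when the transaction fails, instead of unconditionally. A stand-alone toy of the two queues (not synapse's code, which threads these through `runInteraction` below):

```python
# Toy illustration: "after" callbacks fire only on commit,
# "exception" callbacks only on failure.
after_callbacks = []
exception_callbacks = []

def run_transaction(func):
    try:
        result = func()
    except Exception:
        for cb, args, kwargs in exception_callbacks:
            cb(*args, **kwargs)
        raise
    for cb, args, kwargs in after_callbacks:
        cb(*args, **kwargs)
    return result
```
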
@@ -103,11 +105,11 @@ class LoggingTransaction(object):
"[SQL values] {%s} %r",
self.name, args[0]
)
- except:
+ except Exception:
# Don't let logging failures stop SQL from working
pass
- start = time.time() * 1000
+ start = time.time()
try:
return func(
@@ -117,9 +119,9 @@ class LoggingTransaction(object):
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
- msecs = (time.time() * 1000) - start
- sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
- sql_query_timer.inc_by(msecs, sql.split()[0])
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
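
A condensed sketch of the wrapper pattern after this change: durations are kept in seconds throughout, and the histogram label is just the first token of the statement. `timed_execute` is a hypothetical name, and the print stands in for the logger and histogram:

```python
import time

def timed_execute(cursor, sql, args=()):
    # Label by the first token of the statement ("SELECT", "INSERT",
    # ...), exactly what sql.split()[0] yields above.
    verb = sql.split()[0]
    start = time.time()
    try:
        return cursor.execute(sql, args)
    finally:
        secs = time.time() - start   # seconds now, not milliseconds
        print("[SQL time] {%s} %f sec" % (verb, secs))
```
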
class PerformanceCounters(object):
@@ -129,7 +131,7 @@ class PerformanceCounters(object):
def update(self, key, start_time, end_time=None):
if end_time is None:
- end_time = time.time() * 1000
+ end_time = time.time()
duration = end_time - start_time
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
@@ -139,7 +141,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.iteritems():
+ for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -162,7 +164,7 @@ class PerformanceCounters(object):
class SQLBaseStore(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()
@@ -180,10 +182,6 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
- )
-
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@@ -221,14 +219,14 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, final_callbacks,
- logging_context, func, *args, **kwargs):
- start = time.time() * 1000
+ def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
+ func, *args, **kwargs):
+ start = time.time()
txn_id = self._TXN_ID
# We don't really need these to be unique, so let's stop it from
# growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, )
@@ -242,7 +240,7 @@ class SQLBaseStore(object):
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks,
- final_callbacks,
+ exception_callbacks,
)
r = func(txn, *args, **kwargs)
conn.commit()
@@ -283,73 +281,85 @@ class SQLBaseStore(object):
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
- end = time.time() * 1000
+ end = time.time()
duration = end - start
- if logging_context is not None:
- logging_context.add_database_transaction(duration)
+ LoggingContext.current_context().add_database_transaction(duration)
- transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.inc_by(duration, desc)
+ sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
-
- start_time = time.time() * 1000
+ """Starts a transaction on the database and runs a given function
- after_callbacks = []
- final_callbacks = []
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runInteraction") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- current_context.copy_to(context)
- return self._new_transaction(
- conn, desc, after_callbacks, final_callbacks, current_context,
- func, *args, **kwargs
- )
+ Returns:
+ Deferred: The result of func
+ """
+ after_callbacks = []
+ exception_callbacks = []
try:
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self.runWithConnection(
+ self._new_transaction,
+ desc, after_callbacks, exception_callbacks, func,
+ *args, **kwargs
+ )
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
- finally:
- for after_callback, after_args, after_kwargs in final_callbacks:
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
+ raise
defer.returnValue(result)
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runInteraction() method on the underlying db_pool."""
- current_context = LoggingContext.current_context()
+ """Wraps the .runWithConnection() method on the underlying db_pool.
- start_time = time.time() * 1000
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ parent_context = LoggingContext.current_context()
+ if parent_context == LoggingContext.sentinel:
+ logger.warn(
+ "Starting db connection from sentinel context: metrics will be lost",
+ )
+ parent_context = None
+
+ start_time = time.time()
def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection") as context:
- sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+ with LoggingContext("runWithConnection", parent_context) as context:
+ sched_duration_sec = time.time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- current_context.copy_to(context)
-
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
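
The hunk is cut mid-function, but the shape of the change is visible: the caller's logcontext is captured on the reactor thread and passed as the *parent* of the context created on the database thread, replacing the old `copy_to` dance, and work arriving from the sentinel (blank) context is flagged because its metrics cannot be attributed. A toy illustration of that capture-then-parent pattern, using a plain thread-local rather than synapse's real `LoggingContext`:

```python
import threading

_current = threading.local()   # stand-in for synapse's logcontext store

def current_context():
    return getattr(_current, "ctx", None)   # None plays the sentinel

def run_on_db_thread(func):
    parent = current_context()   # captured on the calling thread
    if parent is None:
        print("warning: db work from sentinel context, metrics lost")

    def wrapped(*args, **kwargs):
        _current.ctx = parent    # re-established on the worker thread
        try:
            return func(*args, **kwargs)
        finally:
            _current.ctx = None
    return wrapped
```
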
@@ -368,7 +378,7 @@ class SQLBaseStore(object):
Returns:
A list of dicts where the key is the column header.
"""
- col_headers = list(intern(column[0]) for column in cursor.description)
+ col_headers = list(intern(str(column[0])) for column in cursor.description)
results = list(
dict(zip(col_headers, row)) for row in cursor
)
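
The extra `str()` matters on Python 2, where `cursor.description` can hold `unicode` names and the builtin `intern()` accepts only byte strings; on Python 3 it is a no-op for `str`. As a sketch, with a hypothetical helper name:

```python
from six.moves import intern

def intern_column_names(description):
    # On Python 2, DB drivers may return unicode column names and the
    # builtin intern() accepts only byte strings; str() normalises the
    # type on both Python 2 and 3 before interning.
    return [intern(str(column[0])) for column in description]
```
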
@@ -475,23 +485,53 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
+ @defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True):
"""
+
+ `lock` should generally be set to True (the default), but can be set
+ to False if either of the following is true:
+
+ * there is a UNIQUE INDEX on the key columns. In this case a conflict
+ will cause an IntegrityError, and this function will retry the
+ update.
+
+ * we somehow know that we are the only thread which will be updating
+ this table.
+
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
- insertion_values (dict): key/values to use when inserting
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
"""
- return self.runInteraction(
- desc,
- self._simple_upsert_txn, table, keyvalues, values, insertion_values,
- lock
- )
+ attempts = 0
+ while True:
+ try:
+ result = yield self.runInteraction(
+ desc,
+ self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+ lock=lock
+ )
+ defer.returnValue(result)
+ except self.database_engine.module.IntegrityError as e:
+ attempts += 1
+ if attempts >= 5:
+ # don't retry forever, because things other than races
+ # can cause IntegrityErrors
+ raise
+
+ # presumably we raced with another transaction: let's retry.
+ logger.warn(
+ "IntegrityError when upserting into %s; retrying: %s",
+ table, e
+ )
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
lock=True):
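
The retry loop added above exists because, with `lock=False` and a unique index, two racing upserts can both miss on UPDATE and collide on INSERT; the loser gets an `IntegrityError` and should simply try again. The generic shape of the pattern, using sqlite3's `IntegrityError` as a stand-in for `database_engine.module.IntegrityError`:

```python
from sqlite3 import IntegrityError

def upsert_with_retry(do_upsert, max_attempts=5):
    attempts = 0
    while True:
        try:
            return do_upsert()
        except IntegrityError:
            attempts += 1
            if attempts >= max_attempts:
                # IntegrityErrors can have causes other than upsert
                # races, so don't loop forever.
                raise
            # presumably we raced with a concurrent upsert: go again
```
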
@@ -499,37 +539,38 @@ class SQLBaseStore(object):
if lock:
self.database_engine.lock_table(txn, table)
- # Try to update
+ # First try to update.
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- sqlargs = values.values() + keyvalues.values()
+ sqlargs = list(values.values()) + list(keyvalues.values())
txn.execute(sql, sqlargs)
- if txn.rowcount == 0:
- # We didn't update and rows so insert a new one
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(values)
- allvalues.update(insertion_values)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
- sql = "INSERT INTO %s (%s) VALUES (%s)" % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues)
- )
- txn.execute(sql, allvalues.values())
+ # We didn't update any rows so insert a new one
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
- return True
- else:
- return False
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues)
+ )
+ txn.execute(sql, list(allvalues.values()))
+ # successfully inserted
+ return True
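
The rewritten `_simple_upsert_txn` is the classic update-then-insert emulation: UPDATE the non-key columns first, and only INSERT when no row matched. A runnable reduction against sqlite (which, like synapse's sqlite engine, uses `?` placeholders); the table and column names are invented for the demo:

```python
import sqlite3

def simple_upsert(conn, table, keyvalues, values):
    # 1. Try to UPDATE the non-key columns for the matching key...
    sql = "UPDATE %s SET %s WHERE %s" % (
        table,
        ", ".join("%s = ?" % k for k in values),
        " AND ".join("%s = ?" % k for k in keyvalues),
    )
    cur = conn.execute(sql, list(values.values()) + list(keyvalues.values()))
    if cur.rowcount > 0:
        return False                      # updated an existing row

    # 2. ...and only INSERT when nothing matched.
    allvalues = dict(keyvalues, **values)
    sql = "INSERT INTO %s (%s) VALUES (%s)" % (
        table, ", ".join(allvalues), ", ".join("?" for _ in allvalues),
    )
    conn.execute(sql, list(allvalues.values()))
    return True                           # created a new row

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE profiles (user_id TEXT PRIMARY KEY, displayname TEXT)")
print(simple_upsert(conn, "profiles", {"user_id": "@a:hs"}, {"displayname": "Alice"}))   # True
print(simple_upsert(conn, "profiles", {"user_id": "@a:hs"}, {"displayname": "Alicia"}))  # False
```
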
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
- return a single row, returning a single column from it.
+ return a single row, returning multiple columns from it.
Args:
table : string giving the table name
@@ -582,20 +623,18 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- else:
- where = ""
-
sql = (
- "SELECT %(retcol)s FROM %(table)s %(where)s"
+ "SELECT %(retcol)s FROM %(table)s"
) % {
"retcol": retcol,
"table": table,
- "where": where,
}
- txn.execute(sql, keyvalues.values())
+ if keyvalues:
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ txn.execute(sql)
return [r[0] for r in txn]
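
This hunk also shows the most common mechanical change in the diff: on Python 3, `dict.values()` returns a view, which supports neither `+` concatenation nor indexing, so every `values()` is wrapped in `list()`. A short demonstration of the failure and the fix:

```python
keyvalues = {"room_id": "!abc:hs"}
values = {"topic": "general"}

# Python 3 dict views don't support "+", so the old py2 idiom raises:
try:
    args = values.values() + keyvalues.values()
except TypeError as e:
    print("view concatenation failed:", e)

# ...hence the list() wrapping throughout this diff.
args = list(values.values()) + list(keyvalues.values())
print(args)   # ['general', '!abc:hs']
```
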
@@ -606,7 +645,7 @@ class SQLBaseStore(object):
Args:
table (str): table name
- keyvalues (dict): column names and values to select the rows with
+ keyvalues (dict|None): column names and values to select the rows with
retcol (str): column whose value we wish to retrieve.
Returns:
@@ -657,7 +696,7 @@ class SQLBaseStore(object):
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
else:
sql = "SELECT %s FROM %s" % (
", ".join(retcols),
@@ -688,9 +727,12 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
+ # iterables cannot be sliced, so convert to a list first
+ it_list = list(iterable)
+
chunks = [
- iterable[i:i + batch_size]
- for i in xrange(0, len(iterable), batch_size)
+ it_list[i:i + batch_size]
+ for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
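
The slicing fix above is self-contained enough to demonstrate: a generator (or set) passed as `iterable` cannot be sliced, so it is materialised into a list first, and `six.moves.range` papers over the `xrange`/`range` split. A sketch with a hypothetical `chunk` helper:

```python
from six.moves import range   # xrange on py2, the builtin range on py3

def chunk(iterable, batch_size):
    # Generators and sets can't be sliced, so materialise a list first.
    it_list = list(iterable)
    return [
        it_list[i:i + batch_size]
        for i in range(0, len(it_list), batch_size)
    ]

print(chunk((n for n in [0, 1, 2, 3, 4, 5, 6]), 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```
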
@@ -730,7 +772,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -743,6 +785,33 @@ class SQLBaseStore(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
+ def _simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc,
+ self._simple_update_txn,
+ table, keyvalues, updatevalues,
+ )
+
+ @staticmethod
+ def _simple_update_txn(txn, table, keyvalues, updatevalues):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(
+ update_sql,
+ list(updatevalues.values()) + list(keyvalues.values())
+ )
+
+ return txn.rowcount
+
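
`_simple_update_txn` is new shared machinery: it returns the rowcount so that `_simple_update_one_txn` (refactored further down) can enforce exactly-one-row semantics on top of it. A runnable reduction against sqlite, with a stand-in `StoreError`:

```python
import sqlite3

class StoreError(Exception):   # stand-in for synapse.api.errors.StoreError
    def __init__(self, code, msg):
        super(StoreError, self).__init__("%d: %s" % (code, msg))

def simple_update_txn(txn, table, keyvalues, updatevalues):
    where = ("WHERE " + " AND ".join("%s = ?" % k for k in keyvalues)
             if keyvalues else "")
    txn.execute(
        "UPDATE %s SET %s %s" % (
            table, ", ".join("%s = ?" % k for k in updatevalues), where),
        list(updatevalues.values()) + list(keyvalues.values()),
    )
    return txn.rowcount          # callers can distinguish 0, 1 or many

def simple_update_one_txn(txn, table, keyvalues, updatevalues):
    rowcount = simple_update_txn(txn, table, keyvalues, updatevalues)
    if rowcount == 0:
        raise StoreError(404, "No row found")
    if rowcount > 1:
        raise StoreError(500, "More than one row matched")

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT, admin INTEGER)")
conn.execute("INSERT INTO users VALUES ('alice', 0)")
simple_update_one_txn(conn.cursor(), "users", {"name": "alice"}, {"admin": 1})
```
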
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
@@ -768,27 +837,13 @@ class SQLBaseStore(object):
table, keyvalues, updatevalues,
)
- @staticmethod
- def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
- )
+ @classmethod
+ def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
- if txn.rowcount == 0:
+ if rowcount == 0:
raise StoreError(404, "No row found")
- if txn.rowcount > 1:
+ if rowcount > 1:
raise StoreError(500, "More than one row matched")
@staticmethod
@@ -800,7 +855,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- txn.execute(select_sql, keyvalues.values())
+ txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:
@@ -838,7 +893,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
@@ -856,7 +911,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- return txn.execute(sql, keyvalues.values())
+ return txn.execute(sql, list(keyvalues.values()))
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -888,7 +943,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -928,7 +983,7 @@ class SQLBaseStore(object):
txn.close()
if cache:
- min_val = min(cache.itervalues())
+ min_val = min(itervalues(cache))
else:
min_val = max_value
@@ -951,7 +1006,8 @@ class SQLBaseStore(object):
# __exit__ called after the transaction finishes.
ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__()
- txn.call_finally(ctx.__exit__, None, None, None)
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn(
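
Registering `ctx.__exit__` with both `call_on_exception` and `call_after` ensures the stream-id generator's context manager is closed exactly once, whichever way the transaction ends, since `runInteraction` fires exactly one of the two queues. A toy of manually driving a context manager across that lifetime:

```python
# Toy sketch: __enter__ eagerly, __exit__ wired into *both* queues so
# it runs exactly once whether the txn commits or raises.
class StreamIdCtx(object):
    def __enter__(self):
        print("stream id allocated")
        return 42
    def __exit__(self, exc_type, exc, tb):
        print("stream id released")

after_callbacks = []        # run on commit
exception_callbacks = []    # run on failure (the error is then re-raised)

ctx = StreamIdCtx()
stream_id = ctx.__enter__()
exception_callbacks.append((ctx.__exit__, (None, None, None), {}))
after_callbacks.append((ctx.__exit__, (None, None, None), {}))
```
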
@@ -1042,7 +1098,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
- txn.execute(sql, keyvalues.values() + pagevalues)
+ txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),