diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9125bb1198..ee5587c721 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
+from util.id_generators import IdGenerator, StreamIdGenerator
+
from twisted.internet import defer
from collections import namedtuple, OrderedDict
@@ -29,12 +31,15 @@ import functools
import simplejson as json
import sys
import time
+import threading
+DEBUG_CACHES = False
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
metrics = synapse.metrics.get_metrics_for("synapse.storage")
@@ -53,14 +58,78 @@ cache_counter = metrics.register_cache(
)
-# TODO(paul):
-# * more generic key management
-# * consider other eviction strategies - LRU?
-def cached(max_entries=1000):
+class Cache(object):
+
+ def __init__(self, name, max_entries=1000, keylen=1, lru=False):
+ if lru:
+ self.cache = LruCache(max_size=max_entries)
+ self.max_entries = None
+ else:
+ self.cache = OrderedDict()
+ self.max_entries = max_entries
+
+ self.name = name
+ self.keylen = keylen
+ self.sequence = 0
+ self.thread = None
+ caches_by_name[name] = self.cache
+
+ def check_thread(self):
+ expected_thread = self.thread
+ if expected_thread is None:
+ self.thread = threading.current_thread()
+ else:
+ if expected_thread is not threading.current_thread():
+ raise ValueError(
+ "Cache objects can only be accessed from the main thread"
+ )
+
+ def get(self, *keyargs):
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if keyargs in self.cache:
+ cache_counter.inc_hits(self.name)
+ return self.cache[keyargs]
+
+ cache_counter.inc_misses(self.name)
+ raise KeyError()
+
+ def update(self, sequence, *args):
+ self.check_thread()
+ if self.sequence == sequence:
+ # Only update the cache if the caches sequence number matches the
+ # number that the cache had before the SELECT was started (SYN-369)
+ self.prefill(*args)
+
+ def prefill(self, *args): # because I can't *keyargs, value
+ keyargs = args[:-1]
+ value = args[-1]
+
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+
+ if self.max_entries is not None:
+ while len(self.cache) >= self.max_entries:
+ self.cache.popitem(last=False)
+
+ self.cache[keyargs] = value
+
+ def invalidate(self, *keyargs):
+ self.check_thread()
+ if len(keyargs) != self.keylen:
+ raise ValueError("Expected a key to have %d items", self.keylen)
+ # Increment the sequence number so that any SELECT statements that
+ # raced with the INSERT don't update the cache (SYN-369)
+ self.sequence += 1
+ self.cache.pop(keyargs, None)
+
+
+def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
- The function is presumed to take one additional argument, which is used as
- the key for the cache. Cache hits are served directly from the cache;
+ The function is presumed to take zero or more arguments, which are used in
+ a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
@@ -71,34 +140,42 @@ def cached(max_entries=1000):
calling the calculation function.
"""
def wrap(orig):
- cache = OrderedDict()
- name = orig.__name__
-
- caches_by_name[name] = cache
-
- def prefill(key, value):
- while len(cache) > max_entries:
- cache.popitem(last=False)
-
- cache[key] = value
+ cache = Cache(
+ name=orig.__name__,
+ max_entries=max_entries,
+ keylen=num_args,
+ lru=lru,
+ )
@functools.wraps(orig)
@defer.inlineCallbacks
- def wrapped(self, key):
- if key in cache:
- cache_counter.inc_hits(name)
- defer.returnValue(cache[key])
-
- cache_counter.inc_misses(name)
- ret = yield orig(self, key)
- prefill(key, ret)
- defer.returnValue(ret)
-
- def invalidate(key):
- cache.pop(key, None)
-
- wrapped.invalidate = invalidate
- wrapped.prefill = prefill
+ def wrapped(self, *keyargs):
+ try:
+ cached_result = cache.get(*keyargs)
+ if DEBUG_CACHES:
+ actual_result = yield orig(self, *keyargs)
+ if actual_result != cached_result:
+ logger.error(
+ "Stale cache entry %s%r: cached: %r, actual %r",
+ orig.__name__, keyargs,
+ cached_result, actual_result,
+ )
+ raise ValueError("Stale cache entry")
+ defer.returnValue(cached_result)
+ except KeyError:
+ # Get the sequence number of the cache before reading from the
+ # database so that we can tell if the cache is invalidated
+ # while the SELECT is executing (SYN-369)
+ sequence = cache.sequence
+
+ ret = yield orig(self, *keyargs)
+
+ cache.update(sequence, *keyargs + (ret,))
+
+ defer.returnValue(ret)
+
+ wrapped.invalidate = cache.invalidate
+ wrapped.prefill = cache.prefill
return wrapped
return wrap
@@ -108,11 +185,20 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
- __slots__ = ["txn", "name"]
+ __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
- def __init__(self, txn, name):
+ def __init__(self, txn, name, database_engine, after_callbacks):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
+ object.__setattr__(self, "database_engine", database_engine)
+ object.__setattr__(self, "after_callbacks", after_callbacks)
+
+ def call_after(self, callback, *args):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
+ """
+ self.after_callbacks.append((callback, args))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -120,30 +206,37 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
- def execute(self, sql, *args, **kwargs):
+ def execute(self, sql, *args):
+ self._do_execute(self.txn.execute, sql, *args)
+
+ def executemany(self, sql, *args):
+ self._do_execute(self.txn.executemany, sql, *args)
+
+ def _do_execute(self, func, sql, *args):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
- try:
- if args and args[0]:
- values = args[0]
+ sql = self.database_engine.convert_param_style(sql)
+
+ if args:
+ try:
sql_logger.debug(
- "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
- self.name,
- *values
+ "[SQL values] {%s} %r",
+ self.name, args[0]
)
- except:
- # Don't let logging failures stop SQL from working
- pass
+ except:
+ # Don't let logging failures stop SQL from working
+ pass
start = time.time() * 1000
+
try:
- return self.txn.execute(
- sql, *args, **kwargs
+ return func(
+ sql, *args
)
- except:
- logger.exception("[SQL FAIL] {%s}", self.name)
- raise
+ except Exception as e:
+ logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ raise
finally:
msecs = (time.time() * 1000) - start
sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
@@ -205,10 +298,16 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
- self._get_event_cache = LruCache(hs.config.event_cache_size)
+ self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
+ max_entries=hs.config.event_cache_size)
+
+ self.database_engine = hs.database_engine
- # Pretend the getEventCache is just another named cache
- caches_by_name["*getEvent*"] = self._get_event_cache
+ self._stream_id_gen = StreamIdGenerator()
+ self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+ self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+ self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+ self._pushers_id_gen = IdGenerator("pushers", "id", self)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -232,7 +331,7 @@ class SQLBaseStore(object):
time_now - time_then, limit=3
)
- logger.info(
+ perf_logger.info(
"Total database time: %.3f%% {%s} {%s}",
ratio * 100, top_three_counters, top_3_event_counters
)
@@ -246,8 +345,14 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
- def inner_func(txn, *args, **kwargs):
+ after_callbacks = []
+
+ def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
+ if self.database_engine.is_connection_closed(conn):
+ logger.debug("Reconnecting closed database connection")
+ conn.reconnect()
+
current_context.copy_to(context)
start = time.time() * 1000
txn_id = self._TXN_ID
@@ -261,9 +366,48 @@ class SQLBaseStore(object):
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
- return func(LoggingTransaction(txn, name), *args, **kwargs)
- except:
- logger.exception("[TXN FAIL] {%s}", name)
+ i = 0
+ N = 5
+ while True:
+ try:
+ txn = conn.cursor()
+ txn = LoggingTransaction(
+ txn, name, self.database_engine, after_callbacks
+ )
+ return func(txn, *args, **kwargs)
+ except self.database_engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warn(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name, e, i, N
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ except self.database_engine.module.DatabaseError as e:
+ if self.database_engine.is_deadlock(e):
+ logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warn(
+ "[TXN EROLL] {%s} %s",
+ name, e1,
+ )
+ continue
+ raise
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
@@ -276,9 +420,11 @@ class SQLBaseStore(object):
sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext():
- result = yield self._db_pool.runInteraction(
+ result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@@ -307,11 +453,11 @@ class SQLBaseStore(object):
The result of decoder(results)
"""
def interaction(txn):
- cursor = txn.execute(query, args)
+ txn.execute(query, args)
if decoder:
- return decoder(cursor)
+ return decoder(txn)
else:
- return cursor.fetchall()
+ return txn.fetchall()
return self.runInteraction(desc, interaction)
@@ -321,53 +467,94 @@ class SQLBaseStore(object):
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
- def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
+ @defer.inlineCallbacks
+ def _simple_insert(self, table, values, or_ignore=False,
+ desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
table : string giving the table name
values : dict of new column names and values for them
- or_replace : bool; if True performs an INSERT OR REPLACE
"""
- return self.runInteraction(
- "_simple_insert",
- self._simple_insert_txn, table, values, or_replace=or_replace,
- or_ignore=or_ignore,
- )
+ try:
+ yield self.runInteraction(
+ desc,
+ self._simple_insert_txn, table, values,
+ )
+ except self.database_engine.module.IntegrityError:
+ # We have to do or_ignore flag at this layer, since we can't reuse
+ # a cursor after we receive an error from the db.
+ if not or_ignore:
+ raise
@log_function
- def _simple_insert_txn(self, txn, table, values, or_replace=False,
- or_ignore=False):
- sql = "%s INTO %s (%s) VALUES(%s)" % (
- ("INSERT OR REPLACE" if or_replace else
- "INSERT OR IGNORE" if or_ignore else "INSERT"),
+ def _simple_insert_txn(self, txn, table, values):
+ keys, vals = zip(*values.items())
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
- ", ".join(k for k in values),
- ", ".join("?" for k in values)
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys)
)
- logger.debug(
- "[SQL] %s Args=%s",
- sql, values.values(),
+ txn.execute(sql, vals)
+
+ def _simple_insert_many_txn(self, txn, table, values):
+ if not values:
+ return
+
+ # This is a *slight* abomination to get a list of tuples of key names
+ # and a list of tuples of value names.
+ #
+ # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+ # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+ #
+ # The sort is to ensure that we don't rely on dictionary iteration
+ # order.
+ keys, vals = zip(*[
+ zip(
+ *(sorted(i.items(), key=lambda kv: kv[0]))
+ )
+ for i in values
+ if i
+ ])
+
+ for k in keys:
+ if k != keys[0]:
+ raise RuntimeError(
+ "All items must have the same keys"
+ )
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ ", ".join("?" for _ in keys[0])
)
- txn.execute(sql, values.values())
- return txn.lastrowid
+ txn.executemany(sql, vals)
- def _simple_upsert(self, table, keyvalues, values):
+ def _simple_upsert(self, table, keyvalues, values,
+ insertion_values={}, desc="_simple_upsert", lock=True):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
+ insertion_values (dict): key/values to use when inserting
Returns: A deferred
"""
return self.runInteraction(
- "_simple_upsert",
- self._simple_upsert_txn, table, keyvalues, values
+ desc,
+ self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+ lock
)
- def _simple_upsert_txn(self, txn, table, keyvalues, values):
+ def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
+ lock=True):
+ # We need to lock the table :(, unless we're *really* careful
+ if lock:
+ self.database_engine.lock_table(txn, table)
+
# Try to update
sql = "UPDATE %s SET %s WHERE %s" % (
table,
@@ -386,6 +573,7 @@ class SQLBaseStore(object):
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
+ allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
@@ -399,7 +587,7 @@ class SQLBaseStore(object):
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols,
- allow_none=False):
+ allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -411,12 +599,15 @@ class SQLBaseStore(object):
allow_none : If true, return None instead of failing if the SELECT
statement returns no rows
"""
- return self._simple_selectupdate_one(
- table, keyvalues, retcols=retcols, allow_none=allow_none
+ return self.runInteraction(
+ desc,
+ self._simple_select_one_txn,
+ table, keyvalues, retcols, allow_none,
)
def _simple_select_one_onecol(self, table, keyvalues, retcol,
- allow_none=False):
+ allow_none=False,
+ desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it."
@@ -426,7 +617,7 @@ class SQLBaseStore(object):
retcol : string giving the name of the column to return
"""
return self.runInteraction(
- "_simple_select_one_onecol",
+ desc,
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
)
@@ -450,8 +641,7 @@ class SQLBaseStore(object):
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = (
- "SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
- "ORDER BY rowid asc"
+ "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
) % {
"retcol": retcol,
"table": table,
@@ -462,7 +652,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()]
- def _simple_select_onecol(self, table, keyvalues, retcol):
+ def _simple_select_onecol(self, table, keyvalues, retcol,
+ desc="_simple_select_onecol"):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -475,12 +666,13 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- "_simple_select_onecol",
+ desc,
self._simple_select_onecol_txn,
table, keyvalues, retcol
)
- def _simple_select_list(self, table, keyvalues, retcols):
+ def _simple_select_list(self, table, keyvalues, retcols,
+ desc="_simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -491,7 +683,7 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
- "_simple_select_list",
+ desc,
self._simple_select_list_txn,
table, keyvalues, retcols
)
@@ -507,14 +699,14 @@ class SQLBaseStore(object):
retcols : list of strings giving the names of the columns to return
"""
if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
txn.execute(sql, keyvalues.values())
else:
- sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+ sql = "SELECT %s FROM %s" % (
", ".join(retcols),
table
)
@@ -523,7 +715,7 @@ class SQLBaseStore(object):
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
- retcols=None):
+ desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -541,56 +733,81 @@ class SQLBaseStore(object):
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
"""
- return self._simple_selectupdate_one(table, keyvalues, updatevalues,
- retcols=retcols)
+ return self.runInteraction(
+ desc,
+ self._simple_update_one_txn,
+ table, keyvalues, updatevalues,
+ )
- def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
- retcols=None, allow_none=False):
- """ Combined SELECT then UPDATE."""
- if retcols:
- select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
- )
+ def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+ update_sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ )
- if updatevalues:
- update_sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
- )
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+ allow_none=False):
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ )
+ txn.execute(select_sql, keyvalues.values())
+
+ row = txn.fetchone()
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ return dict(zip(retcols, row))
+
+ def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+ retcols=None, allow_none=False,
+ desc="_simple_selectupdate_one"):
+ """ Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
- txn.execute(select_sql, keyvalues.values())
-
- row = txn.fetchone()
- if not row:
- if allow_none:
- return None
- raise StoreError(404, "No row found")
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched")
-
- ret = dict(zip(retcols, row))
+ ret = self._simple_select_one_txn(
+ txn,
+ table=table,
+ keyvalues=keyvalues,
+ retcols=retcols,
+ allow_none=allow_none,
+ )
if updatevalues:
- txn.execute(
- update_sql,
- updatevalues.values() + keyvalues.values()
+ self._simple_update_one_txn(
+ txn,
+ table=table,
+ keyvalues=keyvalues,
+ updatevalues=updatevalues,
)
- if txn.rowcount == 0:
- raise StoreError(404, "No row found")
+ # if txn.rowcount == 0:
+ # raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction("_simple_selectupdate_one", func)
+ return self.runInteraction(desc, func)
- def _simple_delete_one(self, table, keyvalues):
+ def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -609,9 +826,9 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction("_simple_delete_one", func)
+ return self.runInteraction(desc, func)
- def _simple_delete(self, table, keyvalues):
+ def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
@@ -619,7 +836,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction("_simple_delete", self._simple_delete_txn)
+ return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
@@ -670,6 +887,12 @@ class SQLBaseStore(object):
return [e for e in events if e]
+ def _invalidate_get_event_cache(self, event_id):
+ for check_redacted in (False, True):
+ for get_prev_content in (False, True):
+ self._get_event_cache.invalidate(event_id, check_redacted,
+ get_prev_content)
+
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@@ -680,16 +903,14 @@ class SQLBaseStore(object):
sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time
- cache = self._get_event_cache.setdefault(event_id, {})
-
try:
- # Separate cache entries for each way to invoke _get_event_txn
- ret = cache[(check_redacted, get_prev_content, allow_rejected)]
+ ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
- cache_counter.inc_hits("*getEvent*")
- return ret
+ if allow_rejected or not ret.rejected_reason:
+ return ret
+ else:
+ return None
except KeyError:
- cache_counter.inc_misses("*getEvent*")
pass
finally:
start_time = update_counter("event_cache", start_time)
@@ -714,19 +935,22 @@ class SQLBaseStore(object):
start_time = update_counter("select_event", start_time)
+ result = self._get_event_from_row_txn(
+ txn, internal_metadata, js, redacted,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=rejected_reason,
+ )
+ self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
+
if allow_rejected or not rejected_reason:
- result = self._get_event_from_row_txn(
- txn, internal_metadata, js, redacted,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- )
- cache[(check_redacted, get_prev_content, allow_rejected)] = result
return result
else:
return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
- check_redacted=True, get_prev_content=False):
+ check_redacted=True, get_prev_content=False,
+ rejected_reason=None):
start_time = time.time() * 1000
@@ -741,7 +965,11 @@ class SQLBaseStore(object):
internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time)
- ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
+ ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
start_time = update_counter("build_frozen_event", start_time)
if check_redacted and redacted:
@@ -788,6 +1016,19 @@ class SQLBaseStore(object):
result = txn.fetchone()
return result[0] if result else None
+ def get_next_stream_id(self):
+ with self._next_stream_id_lock:
+ i = self._next_stream_id
+ self._next_stream_id += 1
+ return i
+
+
+class _RollbackButIsFineException(Exception):
+ """ This exception is used to rollback a transaction without implying
+ something went wrong.
+ """
+ pass
+
class Table(object):
""" A base class used to store information about a particular table.
@@ -804,7 +1045,7 @@ class Table(object):
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
- _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+ _insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
|