diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7e3903859b..983ce026e1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -41,7 +41,7 @@ try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
- MAX_TXN_ID = 2**63 - 1
+ MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -76,12 +76,18 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
+
__slots__ = [
- "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
+ "txn",
+ "name",
+ "database_engine",
+ "after_callbacks",
+ "exception_callbacks",
]
- def __init__(self, txn, name, database_engine, after_callbacks,
- exception_callbacks):
+ def __init__(
+ self, txn, name, database_engine, after_callbacks, exception_callbacks
+ ):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
@@ -110,6 +116,7 @@ class LoggingTransaction(object):
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
+
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
@@ -134,10 +141,7 @@ class LoggingTransaction(object):
sql = self.database_engine.convert_param_style(sql)
if args:
try:
- sql_logger.debug(
- "[SQL values] {%s} %r",
- self.name, args[0]
- )
+ sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
except Exception:
# Don't let logging failures stop SQL from working
pass
@@ -145,9 +149,7 @@ class LoggingTransaction(object):
start = time.time()
try:
- return func(
- sql, *args
- )
+ return func(sql, *args)
except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
@@ -176,11 +178,9 @@ class PerformanceCounters(object):
counters = []
for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
- counters.append((
- (cum_time - prev_time) / interval_duration,
- count - prev_count,
- name
- ))
+ counters.append(
+ ((cum_time - prev_time) / interval_duration, count - prev_count, name)
+ )
self.previous_counters = dict(self.current_counters)
@@ -212,8 +212,9 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
- self._get_event_cache = Cache("*getEvent*", keylen=3,
- max_entries=hs.config.event_cache_size)
+ self._get_event_cache = Cache(
+ "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ )
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
@@ -239,7 +240,7 @@ class SQLBaseStore(object):
0.0,
run_as_background_process,
"upsert_safety_check",
- self._check_safe_to_upsert
+ self._check_safe_to_upsert,
)
@defer.inlineCallbacks
@@ -271,7 +272,7 @@ class SQLBaseStore(object):
15.0,
run_as_background_process,
"upsert_safety_check",
- self._check_safe_to_upsert
+ self._check_safe_to_upsert,
)
def start_profiling(self):
@@ -298,13 +299,16 @@ class SQLBaseStore(object):
perf_logger.info(
"Total database time: %.3f%% {%s} {%s}",
- ratio * 100, top_three_counters, top_3_event_counters
+ ratio * 100,
+ top_three_counters,
+ top_3_event_counters,
)
self._clock.looping_call(loop, 10000)
- def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
- func, *args, **kwargs):
+ def _new_transaction(
+ self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+ ):
start = time.time()
txn_id = self._TXN_ID
@@ -312,7 +316,7 @@ class SQLBaseStore(object):
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
- name = "%s-%x" % (desc, txn_id, )
+ name = "%s-%x" % (desc, txn_id)
transaction_logger.debug("[TXN START] {%s}", name)
@@ -323,7 +327,10 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
- txn, name, self.database_engine, after_callbacks,
+ txn,
+ name,
+ self.database_engine,
+ after_callbacks,
exception_callbacks,
)
r = func(txn, *args, **kwargs)
@@ -334,7 +341,10 @@ class SQLBaseStore(object):
# transaction.
logger.warning(
"[TXN OPERROR] {%s} %s %d/%d",
- name, exception_to_unicode(e), i, N
+ name,
+ exception_to_unicode(e),
+ i,
+ N,
)
if i < N:
i += 1
@@ -342,8 +352,7 @@ class SQLBaseStore(object):
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warning(
- "[TXN EROLL] {%s} %s",
- name, exception_to_unicode(e1),
+ "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
)
continue
raise
@@ -357,7 +366,8 @@ class SQLBaseStore(object):
except self.database_engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s",
- name, exception_to_unicode(e1),
+ name,
+ exception_to_unicode(e1),
)
continue
raise
@@ -396,16 +406,17 @@ class SQLBaseStore(object):
exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warn(
- "Starting db txn '%s' from sentinel context",
- desc,
- )
+ logger.warn("Starting db txn '%s' from sentinel context", desc)
try:
result = yield self.runWithConnection(
self._new_transaction,
- desc, after_callbacks, exception_callbacks, func,
- *args, **kwargs
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -434,7 +445,7 @@ class SQLBaseStore(object):
parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel:
logger.warn(
- "Starting db connection from sentinel context: metrics will be lost",
+ "Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
@@ -453,9 +464,7 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
+ result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
defer.returnValue(result)
@@ -469,9 +478,7 @@ class SQLBaseStore(object):
A list of dicts where the key is the column header.
"""
col_headers = list(intern(str(column[0])) for column in cursor.description)
- results = list(
- dict(zip(col_headers, row)) for row in cursor
- )
+ results = list(dict(zip(col_headers, row)) for row in cursor)
return results
def _execute(self, desc, decoder, query, *args):
@@ -485,6 +492,7 @@ class SQLBaseStore(object):
Returns:
The result of decoder(results)
"""
+
def interaction(txn):
txn.execute(query, args)
if decoder:
@@ -498,8 +506,7 @@ class SQLBaseStore(object):
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
- def _simple_insert(self, table, values, or_ignore=False,
- desc="_simple_insert"):
+ def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -511,10 +518,7 @@ class SQLBaseStore(object):
`or_ignore` is True
"""
try:
- yield self.runInteraction(
- desc,
- self._simple_insert_txn, table, values,
- )
+ yield self.runInteraction(desc, self._simple_insert_txn, table, values)
except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -530,15 +534,13 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys),
- ", ".join("?" for _ in keys)
+ ", ".join("?" for _ in keys),
)
txn.execute(sql, vals)
def _simple_insert_many(self, table, values, desc):
- return self.runInteraction(
- desc, self._simple_insert_many_txn, table, values
- )
+ return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
@staticmethod
def _simple_insert_many_txn(txn, table, values):
@@ -553,24 +555,18 @@ class SQLBaseStore(object):
#
# The sort is to ensure that we don't rely on dictionary iteration
# order.
- keys, vals = zip(*[
- zip(
- *(sorted(i.items(), key=lambda kv: kv[0]))
- )
- for i in values
- if i
- ])
+ keys, vals = zip(
+ *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ )
for k in keys:
if k != keys[0]:
- raise RuntimeError(
- "All items must have the same keys"
- )
+ raise RuntimeError("All items must have the same keys")
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0])
+ ", ".join("?" for _ in keys[0]),
)
txn.executemany(sql, vals)
@@ -583,7 +579,7 @@ class SQLBaseStore(object):
values,
insertion_values={},
desc="_simple_upsert",
- lock=True
+ lock=True,
):
"""
@@ -599,7 +595,7 @@ class SQLBaseStore(object):
Args:
table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
+ keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
@@ -631,17 +627,11 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
- "%s when upserting into %s; retrying: %s", e.__name__, table, e
+ "IntegrityError when upserting into %s; retrying: %s", table, e
)
def _simple_upsert_txn(
- self,
- txn,
- table,
- keyvalues,
- values,
- insertion_values={},
- lock=True,
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
Pick the UPSERT method which works best on the platform. Either the
@@ -665,11 +655,7 @@ class SQLBaseStore(object):
and table not in self._unsafe_to_upsert_tables
):
return self._simple_upsert_txn_native_upsert(
- txn,
- table,
- keyvalues,
- values,
- insertion_values=insertion_values,
+ txn, table, keyvalues, values, insertion_values=insertion_values
)
else:
return self._simple_upsert_txn_emulated(
@@ -714,7 +700,7 @@ class SQLBaseStore(object):
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
- " AND ".join(_getwhere(k) for k in keyvalues)
+ " AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
@@ -726,7 +712,7 @@ class SQLBaseStore(object):
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
- " AND ".join(_getwhere(k) for k in keyvalues)
+ " AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(values.values()) + list(keyvalues.values())
@@ -773,19 +759,14 @@ class SQLBaseStore(object):
latter = "NOTHING"
else:
allvalues.update(values)
- latter = (
- "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- )
+ latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = (
- "INSERT INTO %s (%s) VALUES (%s) "
- "ON CONFLICT (%s) DO %s"
- ) % (
+ sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
- latter
+ latter,
)
txn.execute(sql, list(allvalues.values()))
@@ -870,8 +851,8 @@ class SQLBaseStore(object):
latter = "NOTHING"
value_values = [() for x in range(len(key_values))]
else:
- latter = (
- "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
+ latter = "UPDATE SET " + ", ".join(
+ k + "=EXCLUDED." + k for k in value_names
)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
@@ -889,8 +870,9 @@ class SQLBaseStore(object):
return txn.execute_batch(sql, args)
- def _simple_select_one(self, table, keyvalues, retcols,
- allow_none=False, desc="_simple_select_one"):
+ def _simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
+ ):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -903,14 +885,17 @@ class SQLBaseStore(object):
statement returns no rows
"""
return self.runInteraction(
- desc,
- self._simple_select_one_txn,
- table, keyvalues, retcols, allow_none,
+ desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def _simple_select_one_onecol(self, table, keyvalues, retcol,
- allow_none=False,
- desc="_simple_select_one_onecol"):
+ def _simple_select_one_onecol(
+ self,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=False,
+ desc="_simple_select_one_onecol",
+ ):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -922,17 +907,18 @@ class SQLBaseStore(object):
return self.runInteraction(
desc,
self._simple_select_one_onecol_txn,
- table, keyvalues, retcol, allow_none=allow_none,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=allow_none,
)
@classmethod
- def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
- allow_none=False):
+ def _simple_select_one_onecol_txn(
+ cls, txn, table, keyvalues, retcol, allow_none=False
+ ):
ret = cls._simple_select_onecol_txn(
- txn,
- table=table,
- keyvalues=keyvalues,
- retcol=retcol,
+ txn, table=table, keyvalues=keyvalues, retcol=retcol
)
if ret:
@@ -945,12 +931,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- sql = (
- "SELECT %(retcol)s FROM %(table)s"
- ) % {
- "retcol": retcol,
- "table": table,
- }
+ sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
@@ -960,8 +941,9 @@ class SQLBaseStore(object):
return [r[0] for r in txn]
- def _simple_select_onecol(self, table, keyvalues, retcol,
- desc="_simple_select_onecol"):
+ def _simple_select_onecol(
+ self, table, keyvalues, retcol, desc="_simple_select_onecol"
+ ):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -974,13 +956,12 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- desc,
- self._simple_select_onecol_txn,
- table, keyvalues, retcol
+ desc, self._simple_select_onecol_txn, table, keyvalues, retcol
)
- def _simple_select_list(self, table, keyvalues, retcols,
- desc="_simple_select_list"):
+ def _simple_select_list(
+ self, table, keyvalues, retcols, desc="_simple_select_list"
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -994,9 +975,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
- desc,
- self._simple_select_list_txn,
- table, keyvalues, retcols
+ desc, self._simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
@@ -1016,22 +995,26 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(sql, list(keyvalues.values()))
else:
- sql = "SELECT %s FROM %s" % (
- ", ".join(retcols),
- table
- )
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
- def _simple_select_many_batch(self, table, column, iterable, retcols,
- keyvalues={}, desc="_simple_select_many_batch",
- batch_size=100):
+ def _simple_select_many_batch(
+ self,
+ table,
+ column,
+ iterable,
+ retcols,
+ keyvalues={},
+ desc="_simple_select_many_batch",
+ batch_size=100,
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
it_list = list(iterable)
chunks = [
- it_list[i:i + batch_size]
- for i in range(0, len(it_list), batch_size)
+ it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
desc,
self._simple_select_many_txn,
- table, column, chunk, keyvalues, retcols
+ table,
+ column,
+ chunk,
+ keyvalues,
+ retcols,
)
results.extend(rows)
@@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
clauses = []
values = []
- clauses.append(
- "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
- )
+ clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues):
@@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
values.append(value)
if clauses:
- sql = "%s WHERE %s" % (
- sql,
- " AND ".join(clauses),
- )
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
def _simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
- desc,
- self._simple_update_txn,
- table, keyvalues, updatevalues,
+ desc, self._simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
@@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
where,
)
- txn.execute(
- update_sql,
- list(updatevalues.values()) + list(keyvalues.values())
- )
+ txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
return txn.rowcount
- def _simple_update_one(self, table, keyvalues, updatevalues,
- desc="_simple_update_one"):
+ def _simple_update_one(
+ self, table, keyvalues, updatevalues, desc="_simple_update_one"
+ ):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
the update column in the 'keyvalues' dict as well.
"""
return self.runInteraction(
- desc,
- self._simple_update_one_txn,
- table, keyvalues, updatevalues,
+ desc, self._simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
@@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
- def _simple_select_one_txn(txn, table, keyvalues, retcols,
- allow_none=False):
+ def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(select_sql, list(keyvalues.values()))
@@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction(
- desc, self._simple_delete_one_txn, table, keyvalues
- )
+ return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
@@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(sql, list(keyvalues.values()))
@@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
def _simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(
- desc, self._simple_delete_txn, table, keyvalues
- )
+ return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
)
return txn.execute(sql, list(keyvalues.values()))
@@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
clauses = []
values = []
- clauses.append(
- "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
- )
+ clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues):
@@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
values.append(value)
if clauses:
- sql = "%s WHERE %s" % (
- sql,
- " AND ".join(clauses),
- )
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
return txn.execute(sql, values)
- def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
- max_value, limit=100000):
+ def _get_cache_dict(
+ self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+ ):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
- cache = {
- row[0]: int(row[1])
- for row in txn
- }
+ cache = {row[0]: int(row[1]) for row in txn}
txn.close()
@@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
# be safe.
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
- self._send_invalidation_to_replication(
- txn, _CURRENT_STATE_CACHE_NAME, keys,
- )
+ self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
@@ -1355,28 +1316,13 @@ class SQLBaseStore(object):
members_changed (iterable[str]): The user_ids of members that have
changed
"""
- for member in members_changed:
- self._attempt_to_invalidate_cache(
- "get_rooms_for_user_with_stream_ordering", (member,),
- )
-
for host in set(get_domain_from_id(u) for u in members_changed):
- self._attempt_to_invalidate_cache(
- "is_host_joined", (room_id, host,),
- )
- self._attempt_to_invalidate_cache(
- "was_host_joined", (room_id, host,),
- )
+ self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
+ self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
- self._attempt_to_invalidate_cache(
- "get_users_in_room", (room_id,),
- )
- self._attempt_to_invalidate_cache(
- "get_room_summary", (room_id,),
- )
- self._attempt_to_invalidate_cache(
- "get_current_state_ids", (room_id,),
- )
+ self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
+ self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
+ self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(self, cache_name, key):
"""Attempts to invalidate the cache of the given name, ignoring if the
@@ -1424,7 +1370,7 @@ class SQLBaseStore(object):
"cache_func": cache_name,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
- }
+ },
)
def get_all_updated_caches(self, last_id, current_id, limit):
@@ -1440,11 +1386,10 @@ class SQLBaseStore(object):
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_id, limit,))
+ txn.execute(sql, (last_id, limit))
return txn.fetchall()
- return self.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
+
+ return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
def get_cache_stream_token(self):
if self._cache_id_gen:
@@ -1452,33 +1397,61 @@ class SQLBaseStore(object):
else:
return 0
- def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
- desc="_simple_select_list_paginate"):
- """Executes a SELECT query on the named table with start and limit,
+ def _simple_select_list_paginate(
+ self,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction="ASC",
+ desc="_simple_select_list_paginate",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table (str): the table name
- keyvalues (dict[str, Any] | None):
+ keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
- order (str): order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
- table, keyvalues, pagevalues, retcols
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction=order_direction,
)
@classmethod
- def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
- """Executes a SELECT query on the named table with start and limit,
+ def _simple_select_list_paginate_txn(
+ cls,
+ txn,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction="ASC",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
@@ -1488,66 +1461,32 @@ class SQLBaseStore(object):
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
- pagevalues ([]):
- order (str): order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
-
"""
+ if order_direction not in ["ASC", "DESC"]:
+ raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- " ? ASC LIMIT ? OFFSET ?"
- )
- txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
+ where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
else:
- sql = "SELECT %s FROM %s ORDER BY %s" % (
- ", ".join(retcols),
- table,
- " ? ASC LIMIT ? OFFSET ?"
- )
- txn.execute(sql, pagevalues)
-
- return cls.cursor_to_dict(txn)
+ where_clause = ""
- @defer.inlineCallbacks
- def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
- desc="get_user_list_paginate"):
- """Get a list of users from start row to a limit number of rows. This will
- return a json object with users and total number of users in users list.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- pagevalues ([]):
- order (str): order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
- """
- users = yield self.runInteraction(
- desc,
- self._simple_select_list_paginate_txn,
- table, keyvalues, pagevalues, retcols
- )
- count = yield self.runInteraction(
- desc,
- self.get_user_count_txn
+ sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+ ", ".join(retcols),
+ table,
+ where_clause,
+ orderby,
+ order_direction,
)
- retval = {
- "users": users,
- "total": count
- }
- defer.returnValue(retval)
+ txn.execute(sql, list(keyvalues.values()) + [limit, start])
+
+ return cls.cursor_to_dict(txn)
def get_user_count_txn(self, txn):
"""Get a total number of registered users in the users list.
@@ -1561,8 +1500,9 @@ class SQLBaseStore(object):
txn.execute(sql_count)
return txn.fetchone()[0]
- def _simple_search_list(self, table, term, col, retcols,
- desc="_simple_search_list"):
+ def _simple_search_list(
+ self, table, term, col, retcols, desc="_simple_search_list"
+ ):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1577,9 +1517,7 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
- desc,
- self._simple_search_list_txn,
- table, term, col, retcols
+ desc, self._simple_search_list_txn, table, term, col, retcols
)
@classmethod
@@ -1598,11 +1536,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
- ", ".join(retcols),
- table,
- col
- )
+ sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
@@ -1623,6 +1557,7 @@ class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
"""
+
pass
|