diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 485 |
1 files changed, 210 insertions, 275 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7e3903859b..983ce026e1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -41,7 +41,7 @@ try: MAX_TXN_ID = sys.maxint - 1 except AttributeError: # python 3 does not have a maximum int value - MAX_TXN_ID = 2**63 - 1 + MAX_TXN_ID = 2 ** 63 - 1 sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") @@ -76,12 +76,18 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" + __slots__ = [ - "txn", "name", "database_engine", "after_callbacks", "exception_callbacks", + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", ] - def __init__(self, txn, name, database_engine, after_callbacks, - exception_callbacks): + def __init__( + self, txn, name, database_engine, after_callbacks, exception_callbacks + ): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) @@ -110,6 +116,7 @@ class LoggingTransaction(object): def execute_batch(self, sql, args): if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: for val in args: @@ -134,10 +141,7 @@ class LoggingTransaction(object): sql = self.database_engine.convert_param_style(sql) if args: try: - sql_logger.debug( - "[SQL values] {%s} %r", - self.name, args[0] - ) + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) except Exception: # Don't let logging failures stop SQL from working pass @@ -145,9 +149,7 @@ class LoggingTransaction(object): start = time.time() try: - return func( - sql, *args - ) + return func(sql, *args) except Exception as e: logger.debug("[SQL FAIL] {%s} %s", self.name, e) raise @@ -176,11 +178,9 @@ class PerformanceCounters(object): counters = [] for name, (count, cum_time) in iteritems(self.current_counters): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append(( - (cum_time - prev_time) / interval_duration, - count - prev_count, - name - )) + counters.append( + ((cum_time - prev_time) / interval_duration, count - prev_count, name) + ) self.previous_counters = dict(self.current_counters) @@ -212,8 +212,9 @@ class SQLBaseStore(object): self._txn_perf_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters() - self._get_event_cache = Cache("*getEvent*", keylen=3, - max_entries=hs.config.event_cache_size) + self._get_event_cache = Cache( + "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + ) self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] @@ -239,7 +240,7 @@ class SQLBaseStore(object): 0.0, run_as_background_process, "upsert_safety_check", - self._check_safe_to_upsert + self._check_safe_to_upsert, ) @defer.inlineCallbacks @@ -271,7 +272,7 @@ class SQLBaseStore(object): 15.0, run_as_background_process, "upsert_safety_check", - self._check_safe_to_upsert + self._check_safe_to_upsert, ) def start_profiling(self): @@ -298,13 +299,16 @@ class SQLBaseStore(object): perf_logger.info( "Total database time: %.3f%% {%s} {%s}", - ratio * 100, top_three_counters, top_3_event_counters + ratio * 100, + top_three_counters, + top_3_event_counters, ) self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, - func, *args, **kwargs): + def _new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): start = time.time() txn_id = self._TXN_ID @@ -312,7 +316,7 @@ class SQLBaseStore(object): # growing really large. self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - name = "%s-%x" % (desc, txn_id, ) + name = "%s-%x" % (desc, txn_id) transaction_logger.debug("[TXN START] {%s}", name) @@ -323,7 +327,10 @@ class SQLBaseStore(object): try: txn = conn.cursor() txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks, + txn, + name, + self.database_engine, + after_callbacks, exception_callbacks, ) r = func(txn, *args, **kwargs) @@ -334,7 +341,10 @@ class SQLBaseStore(object): # transaction. logger.warning( "[TXN OPERROR] {%s} %s %d/%d", - name, exception_to_unicode(e), i, N + name, + exception_to_unicode(e), + i, + N, ) if i < N: i += 1 @@ -342,8 +352,7 @@ class SQLBaseStore(object): conn.rollback() except self.database_engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) ) continue raise @@ -357,7 +366,8 @@ class SQLBaseStore(object): except self.database_engine.module.Error as e1: logger.warning( "[TXN EROLL] {%s} %s", - name, exception_to_unicode(e1), + name, + exception_to_unicode(e1), ) continue raise @@ -396,16 +406,17 @@ class SQLBaseStore(object): exception_callbacks = [] if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warn( - "Starting db txn '%s' from sentinel context", - desc, - ) + logger.warn("Starting db txn '%s' from sentinel context", desc) try: result = yield self.runWithConnection( self._new_transaction, - desc, after_callbacks, exception_callbacks, func, - *args, **kwargs + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs ) for after_callback, after_args, after_kwargs in after_callbacks: @@ -434,7 +445,7 @@ class SQLBaseStore(object): parent_context = LoggingContext.current_context() if parent_context == LoggingContext.sentinel: logger.warn( - "Starting db connection from sentinel context: metrics will be lost", + "Starting db connection from sentinel context: metrics will be lost" ) parent_context = None @@ -453,9 +464,7 @@ class SQLBaseStore(object): return func(conn, *args, **kwargs) with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) + result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs) defer.returnValue(result) @@ -469,9 +478,7 @@ class SQLBaseStore(object): A list of dicts where the key is the column header. """ col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list( - dict(zip(col_headers, row)) for row in cursor - ) + results = list(dict(zip(col_headers, row)) for row in cursor) return results def _execute(self, desc, decoder, query, *args): @@ -485,6 +492,7 @@ class SQLBaseStore(object): Returns: The result of decoder(results) """ + def interaction(txn): txn.execute(query, args) if decoder: @@ -498,8 +506,7 @@ class SQLBaseStore(object): # no complex WHERE clauses, just a dict of values for columns. @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, - desc="_simple_insert"): + def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): """Executes an INSERT query on the named table. Args: @@ -511,10 +518,7 @@ class SQLBaseStore(object): `or_ignore` is True """ try: - yield self.runInteraction( - desc, - self._simple_insert_txn, table, values, - ) + yield self.runInteraction(desc, self._simple_insert_txn, table, values) except self.database_engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -530,15 +534,13 @@ class SQLBaseStore(object): sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in keys), - ", ".join("?" for _ in keys) + ", ".join("?" for _ in keys), ) txn.execute(sql, vals) def _simple_insert_many(self, table, values, desc): - return self.runInteraction( - desc, self._simple_insert_many_txn, table, values - ) + return self.runInteraction(desc, self._simple_insert_many_txn, table, values) @staticmethod def _simple_insert_many_txn(txn, table, values): @@ -553,24 +555,18 @@ class SQLBaseStore(object): # # The sort is to ensure that we don't rely on dictionary iteration # order. - keys, vals = zip(*[ - zip( - *(sorted(i.items(), key=lambda kv: kv[0])) - ) - for i in values - if i - ]) + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) for k in keys: if k != keys[0]: - raise RuntimeError( - "All items must have the same keys" - ) + raise RuntimeError("All items must have the same keys") sql = "INSERT INTO %s (%s) VALUES(%s)" % ( table, ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]) + ", ".join("?" for _ in keys[0]), ) txn.executemany(sql, vals) @@ -583,7 +579,7 @@ class SQLBaseStore(object): values, insertion_values={}, desc="_simple_upsert", - lock=True + lock=True, ): """ @@ -599,7 +595,7 @@ class SQLBaseStore(object): Args: table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values + keyvalues (dict): The unique key columns and their new values values (dict): The nonunique columns and their new values insertion_values (dict): additional key/values to use only when inserting @@ -631,17 +627,11 @@ class SQLBaseStore(object): # presumably we raced with another transaction: let's retry. logger.warn( - "%s when upserting into %s; retrying: %s", e.__name__, table, e + "IntegrityError when upserting into %s; retrying: %s", table, e ) def _simple_upsert_txn( - self, - txn, - table, - keyvalues, - values, - insertion_values={}, - lock=True, + self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ Pick the UPSERT method which works best on the platform. Either the @@ -665,11 +655,7 @@ class SQLBaseStore(object): and table not in self._unsafe_to_upsert_tables ): return self._simple_upsert_txn_native_upsert( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, + txn, table, keyvalues, values, insertion_values=insertion_values ) else: return self._simple_upsert_txn_emulated( @@ -714,7 +700,7 @@ class SQLBaseStore(object): # SELECT instead to see if it exists. sql = "SELECT 1 FROM %s WHERE %s" % ( table, - " AND ".join(_getwhere(k) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues), ) sqlargs = list(keyvalues.values()) txn.execute(sql, sqlargs) @@ -726,7 +712,7 @@ class SQLBaseStore(object): sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues) + " AND ".join(_getwhere(k) for k in keyvalues), ) sqlargs = list(values.values()) + list(keyvalues.values()) @@ -773,19 +759,14 @@ class SQLBaseStore(object): latter = "NOTHING" else: allvalues.update(values) - latter = ( - "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - ) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = ( - "INSERT INTO %s (%s) VALUES (%s) " - "ON CONFLICT (%s) DO %s" - ) % ( + sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), ", ".join(k for k in keyvalues), - latter + latter, ) txn.execute(sql, list(allvalues.values())) @@ -870,8 +851,8 @@ class SQLBaseStore(object): latter = "NOTHING" value_values = [() for x in range(len(key_values))] else: - latter = ( - "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names) + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names ) sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( @@ -889,8 +870,9 @@ class SQLBaseStore(object): return txn.execute_batch(sql, args) - def _simple_select_one(self, table, keyvalues, retcols, - allow_none=False, desc="_simple_select_one"): + def _simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" + ): """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -903,14 +885,17 @@ class SQLBaseStore(object): statement returns no rows """ return self.runInteraction( - desc, - self._simple_select_one_txn, - table, keyvalues, retcols, allow_none, + desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def _simple_select_one_onecol(self, table, keyvalues, retcol, - allow_none=False, - desc="_simple_select_one_onecol"): + def _simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="_simple_select_one_onecol", + ): """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -922,17 +907,18 @@ class SQLBaseStore(object): return self.runInteraction( desc, self._simple_select_one_onecol_txn, - table, keyvalues, retcol, allow_none=allow_none, + table, + keyvalues, + retcol, + allow_none=allow_none, ) @classmethod - def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol, - allow_none=False): + def _simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): ret = cls._simple_select_onecol_txn( - txn, - table=table, - keyvalues=keyvalues, - retcol=retcol, + txn, table=table, keyvalues=keyvalues, retcol=retcol ) if ret: @@ -945,12 +931,7 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ( - "SELECT %(retcol)s FROM %(table)s" - ) % { - "retcol": retcol, - "table": table, - } + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) @@ -960,8 +941,9 @@ class SQLBaseStore(object): return [r[0] for r in txn] - def _simple_select_onecol(self, table, keyvalues, retcol, - desc="_simple_select_onecol"): + def _simple_select_onecol( + self, table, keyvalues, retcol, desc="_simple_select_onecol" + ): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -974,13 +956,12 @@ class SQLBaseStore(object): Deferred: Results in a list """ return self.runInteraction( - desc, - self._simple_select_onecol_txn, - table, keyvalues, retcol + desc, self._simple_select_onecol_txn, table, keyvalues, retcol ) - def _simple_select_list(self, table, keyvalues, retcols, - desc="_simple_select_list"): + def _simple_select_list( + self, table, keyvalues, retcols, desc="_simple_select_list" + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -994,9 +975,7 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( - desc, - self._simple_select_list_txn, - table, keyvalues, retcols + desc, self._simple_select_list_txn, table, keyvalues, retcols ) @classmethod @@ -1016,22 +995,26 @@ class SQLBaseStore(object): sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) txn.execute(sql, list(keyvalues.values())) else: - sql = "SELECT %s FROM %s" % ( - ", ".join(retcols), - table - ) + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) txn.execute(sql) return cls.cursor_to_dict(txn) @defer.inlineCallbacks - def _simple_select_many_batch(self, table, column, iterable, retcols, - keyvalues={}, desc="_simple_select_many_batch", - batch_size=100): + def _simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="_simple_select_many_batch", + batch_size=100, + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1053,14 +1036,17 @@ class SQLBaseStore(object): it_list = list(iterable) chunks = [ - it_list[i:i + batch_size] - for i in range(0, len(it_list), batch_size) + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) ] for chunk in chunks: rows = yield self.runInteraction( desc, self._simple_select_many_txn, - table, column, chunk, keyvalues, retcols + table, + column, + chunk, + keyvalues, + retcols, ) results.extend(rows) @@ -1089,9 +1075,7 @@ class SQLBaseStore(object): clauses = [] values = [] - clauses.append( - "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) - ) + clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) values.extend(iterable) for key, value in iteritems(keyvalues): @@ -1099,19 +1083,14 @@ class SQLBaseStore(object): values.append(value) if clauses: - sql = "%s WHERE %s" % ( - sql, - " AND ".join(clauses), - ) + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) txn.execute(sql, values) return cls.cursor_to_dict(txn) def _simple_update(self, table, keyvalues, updatevalues, desc): return self.runInteraction( - desc, - self._simple_update_txn, - table, keyvalues, updatevalues, + desc, self._simple_update_txn, table, keyvalues, updatevalues ) @staticmethod @@ -1127,15 +1106,13 @@ class SQLBaseStore(object): where, ) - txn.execute( - update_sql, - list(updatevalues.values()) + list(keyvalues.values()) - ) + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) return txn.rowcount - def _simple_update_one(self, table, keyvalues, updatevalues, - desc="_simple_update_one"): + def _simple_update_one( + self, table, keyvalues, updatevalues, desc="_simple_update_one" + ): """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1154,9 +1131,7 @@ class SQLBaseStore(object): the update column in the 'keyvalues' dict as well. """ return self.runInteraction( - desc, - self._simple_update_one_txn, - table, keyvalues, updatevalues, + desc, self._simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod @@ -1169,12 +1144,11 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, - allow_none=False): + def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, - " AND ".join("%s = ?" % (k,) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) txn.execute(select_sql, list(keyvalues.values())) @@ -1197,9 +1171,7 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ - return self.runInteraction( - desc, self._simple_delete_one_txn, table, keyvalues - ) + return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) @staticmethod def _simple_delete_one_txn(txn, table, keyvalues): @@ -1212,7 +1184,7 @@ class SQLBaseStore(object): """ sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) txn.execute(sql, list(keyvalues.values())) @@ -1222,15 +1194,13 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction( - desc, self._simple_delete_txn, table, keyvalues - ) + return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) @staticmethod def _simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, - " AND ".join("%s = ?" % (k, ) for k in keyvalues) + " AND ".join("%s = ?" % (k,) for k in keyvalues), ) return txn.execute(sql, list(keyvalues.values())) @@ -1260,9 +1230,7 @@ class SQLBaseStore(object): clauses = [] values = [] - clauses.append( - "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) - ) + clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable))) values.extend(iterable) for key, value in iteritems(keyvalues): @@ -1270,14 +1238,12 @@ class SQLBaseStore(object): values.append(value) if clauses: - sql = "%s WHERE %s" % ( - sql, - " AND ".join(clauses), - ) + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) return txn.execute(sql, values) - def _get_cache_dict(self, db_conn, table, entity_column, stream_column, - max_value, limit=100000): + def _get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will # do the right thing to ensure it respects the max size of cache. @@ -1297,10 +1263,7 @@ class SQLBaseStore(object): txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) - cache = { - row[0]: int(row[1]) - for row in txn - } + cache = {row[0]: int(row[1]) for row in txn} txn.close() @@ -1342,9 +1305,7 @@ class SQLBaseStore(object): # be safe. for chunk in batch_iter(members_changed, 50): keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys, - ) + self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys) def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does @@ -1355,28 +1316,13 @@ class SQLBaseStore(object): members_changed (iterable[str]): The user_ids of members that have changed """ - for member in members_changed: - self._attempt_to_invalidate_cache( - "get_rooms_for_user_with_stream_ordering", (member,), - ) - for host in set(get_domain_from_id(u) for u in members_changed): - self._attempt_to_invalidate_cache( - "is_host_joined", (room_id, host,), - ) - self._attempt_to_invalidate_cache( - "was_host_joined", (room_id, host,), - ) + self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) + self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) - self._attempt_to_invalidate_cache( - "get_users_in_room", (room_id,), - ) - self._attempt_to_invalidate_cache( - "get_room_summary", (room_id,), - ) - self._attempt_to_invalidate_cache( - "get_current_state_ids", (room_id,), - ) + self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) + self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) def _attempt_to_invalidate_cache(self, cache_name, key): """Attempts to invalidate the cache of the given name, ignoring if the @@ -1424,7 +1370,7 @@ class SQLBaseStore(object): "cache_func": cache_name, "keys": list(keys), "invalidation_ts": self.clock.time_msec(), - } + }, ) def get_all_updated_caches(self, last_id, current_id, limit): @@ -1440,11 +1386,10 @@ class SQLBaseStore(object): " FROM cache_invalidation_stream" " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_id, limit,)) + txn.execute(sql, (last_id, limit)) return txn.fetchall() - return self.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) + + return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) def get_cache_stream_token(self): if self._cache_id_gen: @@ -1452,33 +1397,61 @@ class SQLBaseStore(object): else: return 0 - def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols, - desc="_simple_select_list_paginate"): - """Executes a SELECT query on the named table with start and limit, + def _simple_select_list_paginate( + self, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + desc="_simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. Args: table (str): the table name - keyvalues (dict[str, Any] | None): + keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive + order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( desc, self._simple_select_list_paginate_txn, - table, keyvalues, pagevalues, retcols + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction=order_direction, ) @classmethod - def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols): - """Executes a SELECT query on the named table with start and limit, + def _simple_select_list_paginate_txn( + cls, + txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. @@ -1488,66 +1461,32 @@ class SQLBaseStore(object): keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. - pagevalues ([]): - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] - """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + if keyvalues: - sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - " ? ASC LIMIT ? OFFSET ?" - ) - txn.execute(sql, list(keyvalues.values()) + list(pagevalues)) + where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) else: - sql = "SELECT %s FROM %s ORDER BY %s" % ( - ", ".join(retcols), - table, - " ? ASC LIMIT ? OFFSET ?" - ) - txn.execute(sql, pagevalues) - - return cls.cursor_to_dict(txn) + where_clause = "" - @defer.inlineCallbacks - def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols, - desc="get_user_list_paginate"): - """Get a list of users from start row to a limit number of rows. This will - return a json object with users and total number of users in users list. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - pagevalues ([]): - order (str): order the select by this column - start (int): start number to begin the query from - limit (int): number of rows to reterive - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} - """ - users = yield self.runInteraction( - desc, - self._simple_select_list_paginate_txn, - table, keyvalues, pagevalues, retcols - ) - count = yield self.runInteraction( - desc, - self.get_user_count_txn + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, ) - retval = { - "users": users, - "total": count - } - defer.returnValue(retval) + txn.execute(sql, list(keyvalues.values()) + [limit, start]) + + return cls.cursor_to_dict(txn) def get_user_count_txn(self, txn): """Get a total number of registered users in the users list. @@ -1561,8 +1500,9 @@ class SQLBaseStore(object): txn.execute(sql_count) return txn.fetchone()[0] - def _simple_search_list(self, table, term, col, retcols, - desc="_simple_search_list"): + def _simple_search_list( + self, table, term, col, retcols, desc="_simple_search_list" + ): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1577,9 +1517,7 @@ class SQLBaseStore(object): """ return self.runInteraction( - desc, - self._simple_search_list_txn, - table, term, col, retcols + desc, self._simple_search_list_txn, table, term, col, retcols ) @classmethod @@ -1598,11 +1536,7 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] or None """ if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % ( - ", ".join(retcols), - table, - col - ) + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) termvalues = ["%%" + term + "%%"] txn.execute(sql, termvalues) else: @@ -1623,6 +1557,7 @@ class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying something went wrong. """ + pass |