diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 438 |
1 files changed, 140 insertions, 298 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1a2b7ebe25..0d7c7dff27 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,11 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging import random import sys -import threading import time from typing import Iterable, Tuple @@ -35,8 +33,6 @@ from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id -from synapse.util import batch_iter -from synapse.util.caches.descriptors import Cache from synapse.util.stringutils import exception_to_unicode # import a function which will return a monotonic time, in seconds @@ -79,10 +75,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", } -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object @@ -237,23 +229,11 @@ class SQLBaseStore(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size - ) - - self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] - self._event_fetch_ongoing = 0 - - self._pending_ds = [] - self.database_engine = hs.database_engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - self._account_validity = self.hs.config.account_validity - # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. @@ -272,14 +252,6 @@ class SQLBaseStore(object): self.rand = random.SystemRandom() - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - @defer.inlineCallbacks def _check_safe_to_upsert(self): """ @@ -290,7 +262,7 @@ class SQLBaseStore(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -312,65 +284,6 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - yield self.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self._simple_insert_txn( - txn, - "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - "email_sent": False, - }, - ) - def start_profiling(self): self._previous_loop_ts = monotonic_time() @@ -394,7 +307,7 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction( + def new_transaction( self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs ): start = monotonic_time() @@ -412,16 +325,15 @@ class SQLBaseStore(object): i = 0 N = 5 while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.database_engine, + after_callbacks, + exception_callbacks, + ) try: - txn = conn.cursor() - txn = LoggingTransaction( - txn, - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - r = func(txn, *args, **kwargs) + r = func(cursor, *args, **kwargs) conn.commit() return r except self.database_engine.module.OperationalError as e: @@ -459,6 +371,40 @@ class SQLBaseStore(object): ) continue raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() except Exception as e: logger.debug("[TXN FAIL] {%s} %s", name, e) raise @@ -498,7 +444,7 @@ class SQLBaseStore(object): try: result = yield self.runWithConnection( - self._new_transaction, + self.new_transaction, desc, after_callbacks, exception_callbacks, @@ -570,7 +516,7 @@ class SQLBaseStore(object): results = list(dict(zip(col_headers, row)) for row in cursor) return results - def _execute(self, desc, decoder, query, *args): + def execute(self, desc, decoder, query, *args): """Runs a single query for a result set. Args: @@ -595,7 +541,7 @@ class SQLBaseStore(object): # no complex WHERE clauses, just a dict of values for columns. @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -611,7 +557,7 @@ class SQLBaseStore(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) + yield self.runInteraction(desc, self.simple_insert_txn, table, values) except self.database_engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -621,7 +567,7 @@ class SQLBaseStore(object): return True @staticmethod - def _simple_insert_txn(txn, table, values): + def simple_insert_txn(txn, table, values): keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -632,11 +578,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod - def _simple_insert_many_txn(txn, table, values): + def simple_insert_many_txn(txn, table, values): if not values: return @@ -665,13 +611,13 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert( + def simple_upsert( self, table, keyvalues, values, insertion_values={}, - desc="_simple_upsert", + desc="simple_upsert", lock=True, ): """ @@ -703,7 +649,7 @@ class SQLBaseStore(object): try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, + self.simple_upsert_txn, table, keyvalues, values, @@ -723,7 +669,7 @@ class SQLBaseStore(object): "IntegrityError when upserting into %s; retrying: %s", table, e ) - def _simple_upsert_txn( + def simple_upsert_txn( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -747,11 +693,11 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_txn_native_upsert( + return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) else: - return self._simple_upsert_txn_emulated( + return self.simple_upsert_txn_emulated( txn, table, keyvalues, @@ -760,7 +706,7 @@ class SQLBaseStore(object): lock=lock, ) - def _simple_upsert_txn_emulated( + def simple_upsert_txn_emulated( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -829,7 +775,7 @@ class SQLBaseStore(object): # successfully inserted return True - def _simple_upsert_txn_native_upsert( + def simple_upsert_txn_native_upsert( self, txn, table, keyvalues, values, insertion_values={} ): """ @@ -854,7 +800,7 @@ class SQLBaseStore(object): allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % ( + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), @@ -863,7 +809,7 @@ class SQLBaseStore(object): ) txn.execute(sql, list(allvalues.values())) - def _simple_upsert_many_txn( + def simple_upsert_many_txn( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -883,15 +829,15 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_many_txn_native_upsert( + return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) else: - return self._simple_upsert_many_txn_emulated( + return self.simple_upsert_many_txn_emulated( txn, table, key_names, key_values, value_names, value_values ) - def _simple_upsert_many_txn_emulated( + def simple_upsert_many_txn_emulated( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -916,9 +862,9 @@ class SQLBaseStore(object): _keys = {x: y for x, y in zip(key_names, keyv)} _vals = {x: y for x, y in zip(value_names, valv)} - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) - def _simple_upsert_many_txn_native_upsert( + def simple_upsert_many_txn_native_upsert( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -963,8 +909,8 @@ class SQLBaseStore(object): return txn.execute_batch(sql, args) - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" ): """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -978,16 +924,16 @@ class SQLBaseStore(object): statement returns no rows """ return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def _simple_select_one_onecol( + def simple_select_one_onecol( self, table, keyvalues, retcol, allow_none=False, - desc="_simple_select_one_onecol", + desc="simple_select_one_onecol", ): """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -999,7 +945,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_one_onecol_txn, + self.simple_select_one_onecol_txn, table, keyvalues, retcol, @@ -1007,10 +953,10 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_one_onecol_txn( + def simple_select_one_onecol_txn( cls, txn, table, keyvalues, retcol, allow_none=False ): - ret = cls._simple_select_onecol_txn( + ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -1023,7 +969,7 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + def simple_select_onecol_txn(txn, table, keyvalues, retcol): sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: @@ -1034,8 +980,8 @@ class SQLBaseStore(object): return [r[0] for r in txn] - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" ): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -1049,12 +995,10 @@ class SQLBaseStore(object): Deferred: Results in a list """ return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol + desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1068,11 +1012,11 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols + desc, self.simple_select_list_txn, table, keyvalues, retcols ) @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1098,14 +1042,14 @@ class SQLBaseStore(object): return cls.cursor_to_dict(txn) @defer.inlineCallbacks - def _simple_select_many_batch( + def simple_select_many_batch( self, table, column, iterable, retcols, keyvalues={}, - desc="_simple_select_many_batch", + desc="simple_select_many_batch", batch_size=100, ): """Executes a SELECT query on the named table, which may return zero or @@ -1134,7 +1078,7 @@ class SQLBaseStore(object): for chunk in chunks: rows = yield self.runInteraction( desc, - self._simple_select_many_txn, + self.simple_select_many_txn, table, column, chunk, @@ -1147,7 +1091,7 @@ class SQLBaseStore(object): return results @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1180,13 +1124,13 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def _simple_update(self, table, keyvalues, updatevalues, desc): + def simple_update(self, table, keyvalues, updatevalues, desc): return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues + desc, self.simple_update_txn, table, keyvalues, updatevalues ) @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): + def simple_update_txn(txn, table, keyvalues, updatevalues): if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) else: @@ -1202,8 +1146,8 @@ class SQLBaseStore(object): return txn.rowcount - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" ): """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1223,12 +1167,12 @@ class SQLBaseStore(object): the update column in the 'keyvalues' dict as well. """ return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues + desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) if rowcount == 0: raise StoreError(404, "No row found (%s)" % (table,)) @@ -1236,7 +1180,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, @@ -1255,7 +1199,7 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1263,10 +1207,10 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): + def simple_delete_one_txn(txn, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1285,11 +1229,11 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod - def _simple_delete_txn(txn, table, keyvalues): + def simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1298,13 +1242,13 @@ class SQLBaseStore(object): txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + def simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): """Executes a DELETE query on the named table. Filters rows by if value of `column` is in `iterable`. @@ -1337,7 +1281,7 @@ class SQLBaseStore(object): return txn.rowcount - def _get_cache_dict( + def get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 ): # Fetch a mapping of room_id -> max stream position for "recent" rooms. @@ -1370,47 +1314,6 @@ class SQLBaseStore(object): return cache, min_val - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] - ) - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1444,73 +1347,17 @@ class SQLBaseStore(object): # which is fine. pass - def _send_invalidation_to_replication(self, txn, cache_name, keys): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name (str) - keys (iterable[str]) - """ - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - self._simple_insert_txn( - txn, - table="cache_invalidation_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_name, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def _simple_select_list_paginate( + def simple_select_list_paginate( self, table, - keyvalues, orderby, start, limit, retcols, + filters=None, + keyvalues=None, order_direction="ASC", - desc="_simple_select_list_paginate", + desc="simple_select_list_paginate", ): """ Executes a SELECT query on the named table with start and limit, @@ -1519,6 +1366,9 @@ class SQLBaseStore(object): Args: table (str): the table name + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. keyvalues (dict[str, T] | None): column names and values to select the rows with, or None to not apply a WHERE clause. @@ -1532,26 +1382,28 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table, - keyvalues, orderby, start, limit, retcols, + filters=filters, + keyvalues=keyvalues, order_direction=order_direction, ) @classmethod - def _simple_select_list_paginate_txn( + def simple_select_list_paginate_txn( cls, txn, table, - keyvalues, orderby, start, limit, retcols, + filters=None, + keyvalues=None, order_direction="ASC", ): """ @@ -1559,16 +1411,23 @@ class SQLBaseStore(object): of row numbers, which may return zero or number of rows from start to limit, returning the result as a list of dicts. + Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to + select attributes with exact matches. All constraints are joined together + using 'AND'. + Args: txn : Transaction object table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. orderby (str): Column to order the results by. start (int): Index to begin the query at. limit (int): Number of results to return. retcols (iterable[str]): the names of the columns to return + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. order_direction (str): Whether the results should be ordered "ASC" or "DESC". Returns: defer.Deferred: resolves to list[dict[str, Any]] @@ -1576,10 +1435,15 @@ class SQLBaseStore(object): if order_direction not in ["ASC", "DESC"]: raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + where_clause = "WHERE " if filters or keyvalues else "" + arg_list = [] + if filters: + where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) + arg_list += list(filters.values()) + where_clause += " AND " if filters and keyvalues else "" if keyvalues: - where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) - else: - where_clause = "" + where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues) + arg_list += list(keyvalues.values()) sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( ", ".join(retcols), @@ -1588,25 +1452,11 @@ class SQLBaseStore(object): orderby, order_direction, ) - txn.execute(sql, list(keyvalues.values()) + [limit, start]) + txn.execute(sql, arg_list + [limit, start]) return cls.cursor_to_dict(txn) - def get_user_count_txn(self, txn): - """Get a total number of registered users in the users list. - - Args: - txn : Transaction object - Returns: - int : number of users - """ - sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" - txn.execute(sql_count) - return txn.fetchone()[0] - - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1621,11 +1471,11 @@ class SQLBaseStore(object): """ return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols + desc, self.simple_search_list_txn, table, term, col, retcols ) @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): + def simple_search_list_txn(cls, txn, table, term, col, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1648,14 +1498,6 @@ class SQLBaseStore(object): return cls.cursor_to_dict(txn) - @property - def database_engine_name(self): - return self.database_engine.module.__name__ - - def get_server_version(self): - """Returns a string describing the server version number""" - return self.database_engine.server_version - class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying |