diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1a2b7ebe25..0d7c7dff27 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,11 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
import random
import sys
-import threading
import time
from typing import Iterable, Tuple
@@ -35,8 +33,6 @@ from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import Cache
from synapse.util.stringutils import exception_to_unicode
# import a function which will return a monotonic time, in seconds
@@ -79,10 +75,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"event_search": "event_search_event_id_idx",
}
-# This is a special cache name we use to batch multiple invalidations of caches
-# based on the current state when notifying workers over replication.
-_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
@@ -237,23 +229,11 @@ class SQLBaseStore(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self._get_event_cache = Cache(
- "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
- )
-
- self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
- self._event_fetch_ongoing = 0
-
- self._pending_ds = []
-
self.database_engine = hs.database_engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
- self._account_validity = self.hs.config.account_validity
-
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@@ -272,14 +252,6 @@ class SQLBaseStore(object):
self.rand = random.SystemRandom()
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@@ -290,7 +262,7 @@ class SQLBaseStore(object):
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = yield self._simple_select_list(
+ updates = yield self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -312,65 +284,6 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- yield self.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- },
- )
-
def start_profiling(self):
self._previous_loop_ts = monotonic_time()
@@ -394,7 +307,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(
+ def new_transaction(
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
start = monotonic_time()
@@ -412,16 +325,15 @@ class SQLBaseStore(object):
i = 0
N = 5
while True:
+ cursor = LoggingTransaction(
+ conn.cursor(),
+ name,
+ self.database_engine,
+ after_callbacks,
+ exception_callbacks,
+ )
try:
- txn = conn.cursor()
- txn = LoggingTransaction(
- txn,
- name,
- self.database_engine,
- after_callbacks,
- exception_callbacks,
- )
- r = func(txn, *args, **kwargs)
+ r = func(cursor, *args, **kwargs)
conn.commit()
return r
except self.database_engine.module.OperationalError as e:
@@ -459,6 +371,40 @@ class SQLBaseStore(object):
)
continue
raise
+ finally:
+ # we're either about to retry with a new cursor, or we're about to
+ # release the connection. Once we release the connection, it could
+ # get used for another query, which might do a conn.rollback().
+ #
+ # In the latter case, even though that probably wouldn't affect the
+ # results of this transaction, python's sqlite will reset all
+ # statements on the connection [1], which will make our cursor
+ # invalid [2].
+ #
+ # In any case, continuing to read rows after commit()ing seems
+ # dubious from the PoV of ACID transactional semantics
+ # (sqlite explicitly says that once you commit, you may see rows
+ # from subsequent updates.)
+ #
+ # In psycopg2, cursors are essentially a client-side fabrication -
+ # all the data is transferred to the client side when the statement
+ # finishes executing - so in theory we could go on streaming results
+ # from the cursor, but attempting to do so would make us
+ # incompatible with sqlite, so let's make sure we're not doing that
+ # by closing the cursor.
+ #
+ # (*named* cursors in psycopg2 are different and are proper server-
+ # side things, but (a) we don't use them and (b) they are implicitly
+ # closed by ending the transaction anyway.)
+ #
+ # In short, if we haven't finished with the cursor yet, that's a
+ # problem waiting to bite us.
+ #
+ # TL;DR: we're done with the cursor, so we can close it.
+ #
+ # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+ # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+ cursor.close()
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
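As a minimal, self-contained sketch of the per-attempt cursor handling this hunk introduces (plain sqlite3 here, none of synapse's runWithConnection/LoggingTransaction machinery; the helper and table names are invented for illustration):

import sqlite3

def run_with_retries(conn, func, attempts=5):
    for attempt in range(attempts):
        cursor = conn.cursor()  # a fresh cursor for every attempt
        try:
            result = func(cursor)
            conn.commit()
            return result
        except sqlite3.OperationalError:
            if attempt == attempts - 1:
                raise  # out of retries; let the error propagate
        finally:
            # whether we committed, failed, or are about to retry, we are done
            # with this cursor, so close it before anything else can touch the
            # connection
            cursor.close()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE widgets (x INTEGER)")
run_with_retries(conn, lambda cur: cur.execute("INSERT INTO widgets VALUES (1)").rowcount)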
@@ -498,7 +444,7 @@ class SQLBaseStore(object):
try:
result = yield self.runWithConnection(
- self._new_transaction,
+ self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
@@ -570,7 +516,7 @@ class SQLBaseStore(object):
results = list(dict(zip(col_headers, row)) for row in cursor)
return results
- def _execute(self, desc, decoder, query, *args):
+ def execute(self, desc, decoder, query, *args):
"""Runs a single query for a result set.
Args:
@@ -595,7 +541,7 @@ class SQLBaseStore(object):
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
- def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
+ def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -611,7 +557,7 @@ class SQLBaseStore(object):
`or_ignore` is True
"""
try:
- yield self.runInteraction(desc, self._simple_insert_txn, table, values)
+ yield self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -621,7 +567,7 @@ class SQLBaseStore(object):
return True
@staticmethod
- def _simple_insert_txn(txn, table, values):
+ def simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -632,11 +578,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
- def _simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
+ def simple_insert_many(self, table, values, desc):
+ return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def _simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(txn, table, values):
if not values:
return
@@ -665,13 +611,13 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
- def _simple_upsert(
+ def simple_upsert(
self,
table,
keyvalues,
values,
insertion_values={},
- desc="_simple_upsert",
+ desc="simple_upsert",
lock=True,
):
"""
@@ -703,7 +649,7 @@ class SQLBaseStore(object):
try:
result = yield self.runInteraction(
desc,
- self._simple_upsert_txn,
+ self.simple_upsert_txn,
table,
keyvalues,
values,
@@ -723,7 +669,7 @@ class SQLBaseStore(object):
"IntegrityError when upserting into %s; retrying: %s", table, e
)
- def _simple_upsert_txn(
+ def simple_upsert_txn(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
@@ -747,11 +693,11 @@ class SQLBaseStore(object):
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
- return self._simple_upsert_txn_native_upsert(
+ return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
else:
- return self._simple_upsert_txn_emulated(
+ return self.simple_upsert_txn_emulated(
txn,
table,
keyvalues,
@@ -760,7 +706,7 @@ class SQLBaseStore(object):
lock=lock,
)
- def _simple_upsert_txn_emulated(
+ def simple_upsert_txn_emulated(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
@@ -829,7 +775,7 @@ class SQLBaseStore(object):
# successfully inserted
return True
- def _simple_upsert_txn_native_upsert(
+ def simple_upsert_txn_native_upsert(
self, txn, table, keyvalues, values, insertion_values={}
):
"""
@@ -854,7 +800,7 @@ class SQLBaseStore(object):
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
+ sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
@@ -863,7 +809,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, list(allvalues.values()))
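For illustration, a self-contained sketch of the statement the native-upsert path above builds; the user_ips table and its columns are example inputs chosen for this sketch, not something taken from the hunk:

# keyvalues identify the row (and form the ON CONFLICT target); values are the
# columns rewritten on conflict via EXCLUDED.<col>
keyvalues = {"user_id": "@alice:example.com", "device_id": "ABCDEF"}
values = {"last_seen": 1234567890}

allvalues = dict(keyvalues)
allvalues.update(values)

latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
    "user_ips",
    ", ".join(k for k in allvalues),
    ", ".join("?" for _ in allvalues),
    ", ".join(k for k in keyvalues),
    latter,
)
assert sql == (
    "INSERT INTO user_ips (user_id, device_id, last_seen) VALUES (?, ?, ?) "
    "ON CONFLICT (user_id, device_id) DO UPDATE SET last_seen=EXCLUDED.last_seen"
)
assert list(allvalues.values()) == ["@alice:example.com", "ABCDEF", 1234567890]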
- def _simple_upsert_many_txn(
+ def simple_upsert_many_txn(
self, txn, table, key_names, key_values, value_names, value_values
):
"""
@@ -883,15 +829,15 @@ class SQLBaseStore(object):
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
- return self._simple_upsert_many_txn_native_upsert(
+ return self.simple_upsert_many_txn_native_upsert(
txn, table, key_names, key_values, value_names, value_values
)
else:
- return self._simple_upsert_many_txn_emulated(
+ return self.simple_upsert_many_txn_emulated(
txn, table, key_names, key_values, value_names, value_values
)
- def _simple_upsert_many_txn_emulated(
+ def simple_upsert_many_txn_emulated(
self, txn, table, key_names, key_values, value_names, value_values
):
"""
@@ -916,9 +862,9 @@ class SQLBaseStore(object):
_keys = {x: y for x, y in zip(key_names, keyv)}
_vals = {x: y for x, y in zip(value_names, valv)}
- self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
- def _simple_upsert_many_txn_native_upsert(
+ def simple_upsert_many_txn_native_upsert(
self, txn, table, key_names, key_values, value_names, value_values
):
"""
@@ -963,8 +909,8 @@ class SQLBaseStore(object):
return txn.execute_batch(sql, args)
- def _simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
+ def simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -978,16 +924,16 @@ class SQLBaseStore(object):
statement returns no rows
"""
return self.runInteraction(
- desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
+ desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def _simple_select_one_onecol(
+ def simple_select_one_onecol(
self,
table,
keyvalues,
retcol,
allow_none=False,
- desc="_simple_select_one_onecol",
+ desc="simple_select_one_onecol",
):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -999,7 +945,7 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
desc,
- self._simple_select_one_onecol_txn,
+ self.simple_select_one_onecol_txn,
table,
keyvalues,
retcol,
@@ -1007,10 +953,10 @@ class SQLBaseStore(object):
)
@classmethod
- def _simple_select_one_onecol_txn(
+ def simple_select_one_onecol_txn(
cls, txn, table, keyvalues, retcol, allow_none=False
):
- ret = cls._simple_select_onecol_txn(
+ ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@@ -1023,7 +969,7 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
@staticmethod
- def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
@@ -1034,8 +980,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn]
- def _simple_select_onecol(
- self, table, keyvalues, retcol, desc="_simple_select_onecol"
+ def simple_select_onecol(
+ self, table, keyvalues, retcol, desc="simple_select_onecol"
):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -1049,12 +995,10 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- desc, self._simple_select_onecol_txn, table, keyvalues, retcol
+ desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def _simple_select_list(
- self, table, keyvalues, retcols, desc="_simple_select_list"
- ):
+ def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1068,11 +1012,11 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
- desc, self._simple_select_list_txn, table, keyvalues, retcols
+ desc, self.simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
- def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1098,14 +1042,14 @@ class SQLBaseStore(object):
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
- def _simple_select_many_batch(
+ def simple_select_many_batch(
self,
table,
column,
iterable,
retcols,
keyvalues={},
- desc="_simple_select_many_batch",
+ desc="simple_select_many_batch",
batch_size=100,
):
"""Executes a SELECT query on the named table, which may return zero or
@@ -1134,7 +1078,7 @@ class SQLBaseStore(object):
for chunk in chunks:
rows = yield self.runInteraction(
desc,
- self._simple_select_many_txn,
+ self.simple_select_many_txn,
table,
column,
chunk,
@@ -1147,7 +1091,7 @@ class SQLBaseStore(object):
return results
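As a rough sketch of why the interaction above runs once per chunk (the chunking helper, table and ids below are invented; synapse's own batch_iter is not shown in this hunk): each batch keeps the IN (...) parameter list to a bounded size.

def chunks_of(iterable, size):
    # naive stand-in for a batching helper: yield successive fixed-size chunks
    items = list(iterable)
    for i in range(0, len(items), size):
        yield items[i : i + size]

event_ids = ["$event%d:example.com" % i for i in range(250)]
queries = []
for chunk in chunks_of(event_ids, 100):
    clause = "event_id IN (%s)" % ", ".join("?" for _ in chunk)
    queries.append(("SELECT event_id, json FROM event_json WHERE " + clause, chunk))

# 250 ids become three bounded queries with 100, 100 and 50 parameters
assert [len(args) for _, args in queries] == [100, 100, 50]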
@classmethod
- def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1180,13 +1124,13 @@ class SQLBaseStore(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def _simple_update(self, table, keyvalues, updatevalues, desc):
+ def simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
- desc, self._simple_update_txn, table, keyvalues, updatevalues
+ desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def _simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
else:
@@ -1202,8 +1146,8 @@ class SQLBaseStore(object):
return txn.rowcount
- def _simple_update_one(
- self, table, keyvalues, updatevalues, desc="_simple_update_one"
+ def simple_update_one(
+ self, table, keyvalues, updatevalues, desc="simple_update_one"
):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -1223,12 +1167,12 @@ class SQLBaseStore(object):
the update column in the 'keyvalues' dict as well.
"""
return self.runInteraction(
- desc, self._simple_update_one_txn, table, keyvalues, updatevalues
+ desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
- rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
+ def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
raise StoreError(404, "No row found (%s)" % (table,))
@@ -1236,7 +1180,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
- def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1255,7 +1199,7 @@ class SQLBaseStore(object):
return dict(zip(retcols, row))
- def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
+ def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -1263,10 +1207,10 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
+ return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def _simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(txn, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -1285,11 +1229,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def _simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
+ def simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def _simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1298,13 +1242,13 @@ class SQLBaseStore(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ def simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
- desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+ desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
@@ -1337,7 +1281,7 @@ class SQLBaseStore(object):
return txn.rowcount
- def _get_cache_dict(
+ def get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -1370,47 +1314,6 @@ class SQLBaseStore(object):
return cache, min_val
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
- """Invalidates the cache and adds it to the cache stream so slaves
- will know to invalidate their caches.
-
- This should only be used to invalidate caches where slaves won't
- otherwise know from other replication streams that the cache should
- be invalidated.
- """
- txn.call_after(cache_func.invalidate, keys)
- self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
-
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
- """Special case invalidation of caches based on current state.
-
- We special case this so that we can batch the cache invalidations into a
- single replication poke.
-
- Args:
- txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
- """
- txn.call_after(self._invalidate_state_caches, room_id, members_changed)
-
- if members_changed:
- # We need to be careful that the size of the `members_changed` list
- # isn't so large that it causes problems sending over replication, so we
- # send them in chunks.
- # Max line length is 16K, and max user ID length is 255, so 50 should
- # be safe.
- for chunk in batch_iter(members_changed, 50):
- keys = itertools.chain([room_id], chunk)
- self._send_invalidation_to_replication(
- txn, _CURRENT_STATE_CACHE_NAME, keys
- )
- else:
- # if no members changed, we still need to invalidate the other caches.
- self._send_invalidation_to_replication(
- txn, _CURRENT_STATE_CACHE_NAME, [room_id]
- )
-
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -1444,73 +1347,17 @@ class SQLBaseStore(object):
# which is fine.
pass
- def _send_invalidation_to_replication(self, txn, cache_name, keys):
- """Notifies replication that given cache has been invalidated.
-
- Note that this does *not* invalidate the cache locally.
-
- Args:
- txn
- cache_name (str)
- keys (iterable[str])
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # get_next() returns a context manager which is designed to wrap
- # the transaction. However, we want to only get an ID when we want
- # to use it, here, so we need to call __enter__ manually, and have
- # __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
- txn.call_after(self.hs.get_notifier().on_new_replication_data)
-
- self._simple_insert_txn(
- txn,
- table="cache_invalidation_stream",
- values={
- "stream_id": stream_id,
- "cache_func": cache_name,
- "keys": list(keys),
- "invalidation_ts": self.clock.time_msec(),
- },
- )
-
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
-
- def get_cache_stream_token(self):
- if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
- else:
- return 0
-
- def _simple_select_list_paginate(
+ def simple_select_list_paginate(
self,
table,
- keyvalues,
orderby,
start,
limit,
retcols,
+ filters=None,
+ keyvalues=None,
order_direction="ASC",
- desc="_simple_select_list_paginate",
+ desc="simple_select_list_paginate",
):
"""
Executes a SELECT query on the named table with start and limit,
@@ -1519,6 +1366,9 @@ class SQLBaseStore(object):
Args:
table (str): the table name
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+                apply a WHERE <column> LIKE ? clause.
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
@@ -1532,26 +1382,28 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
desc,
- self._simple_select_list_paginate_txn,
+ self.simple_select_list_paginate_txn,
table,
- keyvalues,
orderby,
start,
limit,
retcols,
+ filters=filters,
+ keyvalues=keyvalues,
order_direction=order_direction,
)
@classmethod
- def _simple_select_list_paginate_txn(
+ def simple_select_list_paginate_txn(
cls,
txn,
table,
- keyvalues,
orderby,
start,
limit,
retcols,
+ filters=None,
+ keyvalues=None,
order_direction="ASC",
):
"""
@@ -1559,16 +1411,23 @@ class SQLBaseStore(object):
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
+ Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+ select attributes with exact matches. All constraints are joined together
+ using 'AND'.
+
Args:
txn : Transaction object
table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+                apply a WHERE <column> LIKE ? clause.
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
@@ -1576,10 +1435,15 @@ class SQLBaseStore(object):
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+ where_clause = "WHERE " if filters or keyvalues else ""
+ arg_list = []
+ if filters:
+ where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+ arg_list += list(filters.values())
+ where_clause += " AND " if filters and keyvalues else ""
if keyvalues:
- where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
- else:
- where_clause = ""
+ where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ arg_list += list(keyvalues.values())
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols),
@@ -1588,25 +1452,11 @@ class SQLBaseStore(object):
orderby,
order_direction,
)
- txn.execute(sql, list(keyvalues.values()) + [limit, start])
+ txn.execute(sql, arg_list + [limit, start])
return cls.cursor_to_dict(txn)
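As a worked example of the clause assembly above (the users table and its columns are illustrative inputs only): filters entries become LIKE matches, keyvalues entries become exact matches, and the two groups are ANDed together, filters first.

filters = {"name": "%alice%"}    # matched with SQL wildcards via LIKE
keyvalues = {"deactivated": 0}   # matched exactly with =

where_clause = "WHERE " if filters or keyvalues else ""
arg_list = []
if filters:
    where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
    arg_list += list(filters.values())
where_clause += " AND " if filters and keyvalues else ""
if keyvalues:
    where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
    arg_list += list(keyvalues.values())

sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
    "name", "users", where_clause, "name", "ASC",
)
assert sql == (
    "SELECT name FROM users WHERE name LIKE ? AND deactivated = ?"
    " ORDER BY name ASC LIMIT ? OFFSET ?"
)
assert arg_list + [10, 0] == ["%alice%", 0, 10, 0]  # with limit=10, start=0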
- def get_user_count_txn(self, txn):
- """Get a total number of registered users in the users list.
-
- Args:
- txn : Transaction object
- Returns:
- int : number of users
- """
- sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
- txn.execute(sql_count)
- return txn.fetchone()[0]
-
- def _simple_search_list(
- self, table, term, col, retcols, desc="_simple_search_list"
- ):
+ def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1621,11 +1471,11 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
- desc, self._simple_search_list_txn, table, term, col, retcols
+ desc, self.simple_search_list_txn, table, term, col, retcols
)
@classmethod
- def _simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(cls, txn, table, term, col, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1648,14 +1498,6 @@ class SQLBaseStore(object):
return cls.cursor_to_dict(txn)
- @property
- def database_engine_name(self):
- return self.database_engine.module.__name__
-
- def get_server_version(self):
- """Returns a string describing the server version number"""
- return self.database_engine.server_version
-
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying