diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9205e550bb..fd5bb3e1de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,1304 +16,28 @@
# limitations under the License.
import logging
import random
-import sys
-import time
-from typing import Iterable, Tuple
-from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import builtins, intern, range
+from six import PY2
+from six.moves import builtins
from canonicaljson import json
-from prometheus_client import Histogram
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.database import LoggingTransaction # noqa: F401
+from synapse.storage.database import make_in_list_sql_clause # noqa: F401
+from synapse.storage.database import Database
from synapse.types import get_domain_from_id
-from synapse.util.stringutils import exception_to_unicode
-
-# import a function which will return a monotonic time, in seconds
-try:
- # on python 3, use time.monotonic, since time.clock can go backwards
- from time import monotonic as monotonic_time
-except ImportError:
- # ... but python 2 doesn't have it
- from time import clock as monotonic_time
logger = logging.getLogger(__name__)
-try:
- MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
- # python 3 does not have a maximum int value
- MAX_TXN_ID = 2 ** 63 - 1
-
-sql_logger = logging.getLogger("synapse.storage.SQL")
-transaction_logger = logging.getLogger("synapse.storage.txn")
-perf_logger = logging.getLogger("synapse.storage.TIME")
-
-sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-
-sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
-
-
-# Unique indexes which have been added in background updates. Maps from table name
-# to the name of the background update which added the unique index to that table.
-#
-# This is used by the upsert logic to figure out which tables are safe to do a proper
-# UPSERT on: until the relevant background update has completed, we
-# have to emulate an upsert by locking the table.
-#
-UNIQUE_INDEX_BACKGROUND_UPDATES = {
- "user_ips": "user_ips_device_unique_index",
- "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
- "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
- "event_search": "event_search_event_id_idx",
-}
-
-
-class LoggingTransaction(object):
- """An object that almost-transparently proxies for the 'txn' object
- passed to the constructor. Adds logging and metrics to the .execute()
- method.
-
- Args:
-        txn: The database transaction object to wrap.
-        name (str): The name of this transaction for logging.
- database_engine (Sqlite3Engine|PostgresEngine)
-        after_callbacks(list|None): A list to which callbacks added by
-            `call_after` will be appended; these are run on successful
-            completion of the transaction. None indicates that no callbacks
-            may be scheduled.
-        exception_callbacks(list|None): A list to which callbacks added by
-            `call_on_exception` will be appended; these are run if the
-            transaction ends with an error. None indicates that no callbacks
-            may be scheduled.
- """
-
- __slots__ = [
- "txn",
- "name",
- "database_engine",
- "after_callbacks",
- "exception_callbacks",
- ]
-
- def __init__(
- self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
- ):
- object.__setattr__(self, "txn", txn)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "database_engine", database_engine)
- object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "exception_callbacks", exception_callbacks)
-
- def call_after(self, callback, *args, **kwargs):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
- """
- self.after_callbacks.append((callback, args, kwargs))
-
- def call_on_exception(self, callback, *args, **kwargs):
- self.exception_callbacks.append((callback, args, kwargs))
-
- def __getattr__(self, name):
- return getattr(self.txn, name)
-
- def __setattr__(self, name, value):
- setattr(self.txn, name, value)
-
- def __iter__(self):
- return self.txn.__iter__()
-
- def execute_batch(self, sql, args):
- if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch
-
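-            # execute_batch sends the statements to the server in batches,
-            # reducing round-trips compared with the row-at-a-time loop
-            # used for sqlite below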
- self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
- else:
- for val in args:
- self.execute(sql, val)
-
- def execute(self, sql, *args):
- self._do_execute(self.txn.execute, sql, *args)
-
- def executemany(self, sql, *args):
- self._do_execute(self.txn.executemany, sql, *args)
-
- def _make_sql_one_line(self, sql):
- "Strip newlines out of SQL so that the loggers in the DB are on one line"
- return " ".join(l.strip() for l in sql.splitlines() if l.strip())
-
- def _do_execute(self, func, sql, *args):
- sql = self._make_sql_one_line(sql)
-
- # TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] {%s} %s", self.name, sql)
-
- sql = self.database_engine.convert_param_style(sql)
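-        # (this is a no-op on sqlite; on postgres it rewrites the "?"
-        # placeholders into the "%s" style that psycopg2 expects)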
- if args:
- try:
- sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
- except Exception:
- # Don't let logging failures stop SQL from working
- pass
-
- start = time.time()
-
- try:
- return func(sql, *args)
- except Exception as e:
- logger.debug("[SQL FAIL] {%s} %s", self.name, e)
- raise
- finally:
- secs = time.time() - start
- sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
- sql_query_timer.labels(sql.split()[0]).observe(secs)
-
-
-class PerformanceCounters(object):
- def __init__(self):
- self.current_counters = {}
- self.previous_counters = {}
-
- def update(self, key, duration_secs):
- count, cum_time = self.current_counters.get(key, (0, 0))
- count += 1
- cum_time += duration_secs
- self.current_counters[key] = (count, cum_time)
-
- def interval(self, interval_duration_secs, limit=3):
- counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
- prev_count, prev_time = self.previous_counters.get(name, (0, 0))
- counters.append(
- (
- (cum_time - prev_time) / interval_duration_secs,
- count - prev_count,
- name,
- )
- )
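-            # e.g. a counter that accrued 2.5 sec of DB time over 50 new
-            # calls in a 10 sec window has ratio 0.25, rendered below as
-            # "name(50): 25.000%"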
-
- self.previous_counters = dict(self.current_counters)
-
- counters.sort(reverse=True)
-
- top_n_counters = ", ".join(
- "%s(%d): %.3f%%" % (name, count, 100 * ratio)
- for ratio, count, name in counters[:limit]
- )
-
- return top_n_counters
-
class SQLBaseStore(object):
- _TXN_ID = 0
-
def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
-
- self._previous_txn_total_time = 0
- self._current_txn_total_time = 0
- self._previous_loop_ts = 0
-
- # TODO(paul): These can eventually be removed once the metrics code
- # is running in mainline, and we have some nice monitoring frontends
- # to watch it
- self._txn_perf_counters = PerformanceCounters()
-
self.database_engine = hs.database_engine
-
- # A set of tables that are not safe to use native upserts in.
- self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
-
- # We add the user_directory_search table to the blacklist on SQLite
- # because the existing search table does not have an index, making it
- # unsafe to use native upserts.
- if isinstance(self.database_engine, Sqlite3Engine):
- self._unsafe_to_upsert_tables.add("user_directory_search")
-
- if self.database_engine.can_native_upsert:
- # Check ASAP (and then later, every 1s) to see if we have finished
- # background updates of tables that aren't safe to update.
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "upsert_safety_check",
- self._check_safe_to_upsert,
- )
-
+ self.db = Database(hs)
self.rand = random.SystemRandom()
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
- """
- Is it safe to use native UPSERT?
-
-        If there are background updates, we will need to wait, as they may be
-        adding the indexes that provide the UNIQUE constraints we require.
-
- If the background updates have not completed, wait 15 sec and check again.
- """
- updates = yield self.simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=["update_name"],
- desc="check_background_updates",
- )
- updates = [x["update_name"] for x in updates]
-
- for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
- if update_name not in updates:
- logger.debug("Now safe to upsert in %s", table)
- self._unsafe_to_upsert_tables.discard(table)
-
-        # If there are any updates still running, reschedule the check.
- if updates:
- self._clock.call_later(
- 15.0,
- run_as_background_process,
- "upsert_safety_check",
- self._check_safe_to_upsert,
- )
-
- def start_profiling(self):
- self._previous_loop_ts = monotonic_time()
-
- def loop():
- curr = self._current_txn_total_time
- prev = self._previous_txn_total_time
- self._previous_txn_total_time = curr
-
- time_now = monotonic_time()
- time_then = self._previous_loop_ts
- self._previous_loop_ts = time_now
-
- duration = time_now - time_then
- ratio = (curr - prev) / duration
-
- top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
-
- perf_logger.info(
- "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
- )
-
- self._clock.looping_call(loop, 10000)
-
- def new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
- start = monotonic_time()
- txn_id = self._TXN_ID
-
-        # We don't really need these to be unique, so let's stop the counter
-        # from growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
-
- name = "%s-%x" % (desc, txn_id)
-
- transaction_logger.debug("[TXN START] {%s}", name)
-
- try:
- i = 0
- N = 5
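-            # retry the whole transaction up to N times if the connection
-            # drops (OperationalError) or we hit a deadlock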
- while True:
- cursor = LoggingTransaction(
- conn.cursor(),
- name,
- self.database_engine,
- after_callbacks,
- exception_callbacks,
- )
- try:
- r = func(cursor, *args, **kwargs)
- conn.commit()
- return r
- except self.database_engine.module.OperationalError as e:
- # This can happen if the database disappears mid
- # transaction.
- logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d",
- name,
- exception_to_unicode(e),
- i,
- N,
- )
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
- )
- continue
- raise
- except self.database_engine.module.DatabaseError as e:
- if self.database_engine.is_deadlock(e):
- logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s",
- name,
- exception_to_unicode(e1),
- )
- continue
- raise
- finally:
- # we're either about to retry with a new cursor, or we're about to
- # release the connection. Once we release the connection, it could
- # get used for another query, which might do a conn.rollback().
- #
- # In the latter case, even though that probably wouldn't affect the
- # results of this transaction, python's sqlite will reset all
- # statements on the connection [1], which will make our cursor
- # invalid [2].
- #
- # In any case, continuing to read rows after commit()ing seems
- # dubious from the PoV of ACID transactional semantics
- # (sqlite explicitly says that once you commit, you may see rows
- # from subsequent updates.)
- #
- # In psycopg2, cursors are essentially a client-side fabrication -
- # all the data is transferred to the client side when the statement
- # finishes executing - so in theory we could go on streaming results
- # from the cursor, but attempting to do so would make us
- # incompatible with sqlite, so let's make sure we're not doing that
- # by closing the cursor.
- #
- # (*named* cursors in psycopg2 are different and are proper server-
- # side things, but (a) we don't use them and (b) they are implicitly
- # closed by ending the transaction anyway.)
- #
- # In short, if we haven't finished with the cursor yet, that's a
- # problem waiting to bite us.
- #
- # TL;DR: we're done with the cursor, so we can close it.
- #
- # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
- # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
- cursor.close()
- except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
- raise
- finally:
- end = monotonic_time()
- duration = end - start
-
- LoggingContext.current_context().add_database_transaction(duration)
-
- transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
-
- self._current_txn_total_time += duration
- self._txn_perf_counters.update(desc, duration)
- sql_txn_timer.labels(desc).observe(duration)
-
- @defer.inlineCallbacks
- def runInteraction(self, desc, func, *args, **kwargs):
- """Starts a transaction on the database and runs a given function
-
- Arguments:
- desc (str): description of the transaction, for logging and metrics
- func (func): callback function, which will be called with a
- database transaction (twisted.enterprise.adbapi.Transaction) as
- its first argument, followed by `args` and `kwargs`.
-
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
-
- Returns:
- Deferred: The result of func
- """
- after_callbacks = []
- exception_callbacks = []
-
- if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warning("Starting db txn '%s' from sentinel context", desc)
-
- try:
- result = yield self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
- )
-
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except: # noqa: E722, as we reraise the exception this is fine.
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
-
- return result
-
- @defer.inlineCallbacks
- def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runWithConnection() method on the underlying db_pool.
-
- Arguments:
- func (func): callback function, which will be called with a
- database connection (twisted.enterprise.adbapi.Connection) as
- its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
-
- Returns:
- Deferred: The result of func
- """
- parent_context = LoggingContext.current_context()
- if parent_context == LoggingContext.sentinel:
- logger.warning(
- "Starting db connection from sentinel context: metrics will be lost"
- )
- parent_context = None
-
- start_time = monotonic_time()
-
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection", parent_context) as context:
- sched_duration_sec = monotonic_time() - start_time
- sql_scheduling_timer.observe(sched_duration_sec)
- context.add_database_scheduled(sched_duration_sec)
-
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- return func(conn, *args, **kwargs)
-
- result = yield make_deferred_yieldable(
- self._db_pool.runWithConnection(inner_func, *args, **kwargs)
- )
-
- return result
-
- @staticmethod
- def cursor_to_dict(cursor):
- """Converts a SQL cursor into an list of dicts.
-
- Args:
- cursor : The DBAPI cursor which has executed a query.
- Returns:
- A list of dicts where the key is the column header.
- """
- col_headers = list(intern(str(column[0])) for column in cursor.description)
- results = list(dict(zip(col_headers, row)) for row in cursor)
- return results
-
- def execute(self, desc, decoder, query, *args):
- """Runs a single query for a result set.
-
- Args:
- decoder - The function which can resolve the cursor results to
- something meaningful.
- query - The query string to execute
- *args - Query args.
- Returns:
- The result of decoder(results)
- """
-
- def interaction(txn):
- txn.execute(query, args)
- if decoder:
- return decoder(txn)
- else:
- return txn.fetchall()
-
- return self.runInteraction(desc, interaction)
-
- # "Simple" SQL API methods that operate on a single table with no JOINs,
- # no complex WHERE clauses, just a dict of values for columns.
-
- @defer.inlineCallbacks
- def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
- """Executes an INSERT query on the named table.
-
- Args:
- table : string giving the table name
- values : dict of new column names and values for them
-            or_ignore : bool stating whether to ignore the exception raised
-                when a conflicting row already exists. If True, the function
-                returns False instead of raising.
- desc : string giving a description of the transaction
-
- Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
- """
- try:
- yield self.runInteraction(desc, self.simple_insert_txn, table, values)
- except self.database_engine.module.IntegrityError:
- # We have to do or_ignore flag at this layer, since we can't reuse
- # a cursor after we receive an error from the db.
- if not or_ignore:
- raise
- return False
- return True
-
- @staticmethod
- def simple_insert_txn(txn, table, values):
- keys, vals = zip(*values.items())
-
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys),
- ", ".join("?" for _ in keys),
- )
-
- txn.execute(sql, vals)
-
- def simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
-
- @staticmethod
- def simple_insert_many_txn(txn, table, values):
- if not values:
- return
-
- # This is a *slight* abomination to get a list of tuples of key names
- # and a list of tuples of value names.
- #
- # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
- # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
- #
- # The sort is to ensure that we don't rely on dictionary iteration
- # order.
- keys, vals = zip(
- *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
- )
-
- for k in keys:
- if k != keys[0]:
- raise RuntimeError("All items must have the same keys")
-
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0]),
- )
-
- txn.executemany(sql, vals)
-
- @defer.inlineCallbacks
- def simple_upsert(
- self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="simple_upsert",
- lock=True,
- ):
- """
-
- `lock` should generally be set to True (the default), but can be set
- to False if either of the following are true:
-
-        * there is a UNIQUE INDEX on the key columns. In this case a conflict
-          will cause an IntegrityError, and this function will retry
-          the update.
-
- * we somehow know that we are the only thread which will be updating
- this table.
-
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
- """
- attempts = 0
- while True:
- try:
- result = yield self.runInteraction(
- desc,
- self.simple_upsert_txn,
- table,
- keyvalues,
- values,
- insertion_values,
- lock=lock,
- )
- return result
- except self.database_engine.module.IntegrityError as e:
- attempts += 1
- if attempts >= 5:
- # don't retry forever, because things other than races
- # can cause IntegrityErrors
- raise
-
- # presumably we raced with another transaction: let's retry.
- logger.warning(
- "IntegrityError when upserting into %s; retrying: %s", table, e
- )
-
- def simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
- """
- Pick the UPSERT method which works best on the platform. Either the
- native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
-
- Args:
- txn: The transaction to use.
- table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
- """
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
- return self.simple_upsert_txn_native_upsert(
- txn, table, keyvalues, values, insertion_values=insertion_values
- )
- else:
- return self.simple_upsert_txn_emulated(
- txn,
- table,
- keyvalues,
- values,
- insertion_values=insertion_values,
- lock=lock,
- )
-
- def simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
- """
- Args:
- table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- bool: Return True if a new entry was created, False if an existing
- one was updated.
- """
- # We need to lock the table :(, unless we're *really* careful
- if lock:
- self.database_engine.lock_table(txn, table)
-
- def _getwhere(key):
-            # If the value we're passing in is None (aka NULL), we need to use
-            # IS, not =, as NULL = NULL evaluates to NULL, which is never true.
- if keyvalues[key] is None:
- return "%s IS ?" % (key,)
- else:
- return "%s = ?" % (key,)
-
- if not values:
- # If `values` is empty, then all of the values we care about are in
- # the unique key, so there is nothing to UPDATE. We can just do a
- # SELECT instead to see if it exists.
- sql = "SELECT 1 FROM %s WHERE %s" % (
- table,
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
- sqlargs = list(keyvalues.values())
- txn.execute(sql, sqlargs)
- if txn.fetchall():
- # We have an existing record.
- return False
- else:
- # First try to update.
- sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in values),
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
- sqlargs = list(values.values()) + list(keyvalues.values())
-
- txn.execute(sql, sqlargs)
- if txn.rowcount > 0:
- # successfully updated at least one row.
- return False
-
- # We didn't find any existing rows, so insert a new one
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(values)
- allvalues.update(insertion_values)
-
- sql = "INSERT INTO %s (%s) VALUES (%s)" % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues),
- )
- txn.execute(sql, list(allvalues.values()))
- # successfully inserted
- return True
-
- def simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
- """
- Use the native UPSERT functionality in recent PostgreSQL versions.
-
- Args:
- table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
- """
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(insertion_values)
-
- if not values:
- latter = "NOTHING"
- else:
- allvalues.update(values)
- latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-
- sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues),
- ", ".join(k for k in keyvalues),
- latter,
- )
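-        # e.g. with keyvalues={"k": 1} and values={"v": 2} this builds
-        # "INSERT INTO t (k, v) VALUES (?, ?) ON CONFLICT (k) DO UPDATE SET
-        # v=EXCLUDED.v" (an illustrative sketch; column order follows the
-        # iteration order of allvalues)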
- txn.execute(sql, list(allvalues.values()))
-
- def simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
- return self.simple_upsert_many_txn_native_upsert(
- txn, table, key_names, key_values, value_names, value_values
- )
- else:
- return self.simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values
- )
-
- def simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times, but without native UPSERT support or batching.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- # No value columns, therefore make a blank list so that the following
- # zip() works correctly.
- if not value_names:
- value_values = [() for x in range(len(key_values))]
-
- for keyv, valv in zip(key_values, value_values):
- _keys = {x: y for x, y in zip(key_names, keyv)}
- _vals = {x: y for x, y in zip(value_names, valv)}
-
- self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
-
- def simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times, using batching where possible.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- allnames = []
- allnames.extend(key_names)
- allnames.extend(value_names)
-
- if not value_names:
- # No value columns, therefore make a blank list so that the
- # following zip() works correctly.
- latter = "NOTHING"
- value_values = [() for x in range(len(key_values))]
- else:
- latter = "UPDATE SET " + ", ".join(
- k + "=EXCLUDED." + k for k in value_names
- )
-
- sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
- table,
- ", ".join(k for k in allnames),
- ", ".join("?" for _ in allnames),
- ", ".join(key_names),
- latter,
- )
-
- args = []
-
- for x, y in zip(key_values, value_values):
- args.append(tuple(x) + tuple(y))
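-        # e.g. key_values=[("a",), ("b",)] and value_values=[(1,), (2,)]
-        # give args=[("a", 1), ("b", 2)], one tuple per upserted row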
-
- return txn.execute_batch(sql, args)
-
- def simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
- ):
- """Executes a SELECT query on the named table, which is expected to
- return a single row, returning multiple columns from it.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
- """
- return self.runInteraction(
- desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
- )
-
- def simple_select_one_onecol(
- self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="simple_select_one_onecol",
- ):
- """Executes a SELECT query on the named table, which is expected to
- return a single row, returning a single column from it.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
- """
- return self.runInteraction(
- desc,
- self.simple_select_one_onecol_txn,
- table,
- keyvalues,
- retcol,
- allow_none=allow_none,
- )
-
- @classmethod
- def simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
- ret = cls.simple_select_onecol_txn(
- txn, table=table, keyvalues=keyvalues, retcol=retcol
- )
-
- if ret:
- return ret[0]
- else:
- if allow_none:
- return None
- else:
- raise StoreError(404, "No row found")
-
- @staticmethod
- def simple_select_onecol_txn(txn, table, keyvalues, retcol):
- sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
-
- if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
- txn.execute(sql, list(keyvalues.values()))
- else:
- txn.execute(sql)
-
- return [r[0] for r in txn]
-
- def simple_select_onecol(
- self, table, keyvalues, retcol, desc="simple_select_onecol"
- ):
- """Executes a SELECT query on the named table, which returns a list
- comprising of the values of the named column from the selected rows.
-
- Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whose value we wish to retrieve.
-
- Returns:
- Deferred: Results in a list
- """
- return self.runInteraction(
- desc, self.simple_select_onecol_txn, table, keyvalues, retcol
- )
-
- def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc, self.simple_select_list_txn, table, keyvalues, retcols
- )
-
- @classmethod
- def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
- """
- if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
- txn.execute(sql, list(keyvalues.values()))
- else:
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
- txn.execute(sql)
-
- return cls.cursor_to_dict(txn)
-
- @defer.inlineCallbacks
- def simple_select_many_batch(
- self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="simple_select_many_batch",
- batch_size=100,
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
- Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
- """
- results = []
-
- if not iterable:
- return results
-
-        # the iterable cannot be sliced, so convert it to a list first
- it_list = list(iterable)
-
- chunks = [
- it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
- ]
- for chunk in chunks:
- rows = yield self.runInteraction(
- desc,
- self.simple_select_many_txn,
- table,
- column,
- chunk,
- keyvalues,
- retcols,
- )
-
- results.extend(rows)
-
- return results
-
- @classmethod
- def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
- Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
- """
- if not iterable:
- return []
-
- clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
- clauses = [clause]
-
- for key, value in iteritems(keyvalues):
- clauses.append("%s = ?" % (key,))
- values.append(value)
-
- sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join(clauses),
- )
-
- txn.execute(sql, values)
- return cls.cursor_to_dict(txn)
-
- def simple_update(self, table, keyvalues, updatevalues, desc):
- return self.runInteraction(
- desc, self.simple_update_txn, table, keyvalues, updatevalues
- )
-
- @staticmethod
- def simple_update_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
-
- return txn.rowcount
-
- def simple_update_one(
- self, table, keyvalues, updatevalues, desc="simple_update_one"
- ):
- """Executes an UPDATE query on the named table, setting new values for
- columns in a row matching the key values.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
-
-        The update is performed within a single transaction, so it can be used
-        to implement compare-and-set by putting the expected current value of
-        the update column in the 'keyvalues' dict as well.
- """
- return self.runInteraction(
- desc, self.simple_update_one_txn, table, keyvalues, updatevalues
- )
-
- @classmethod
- def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
- rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
-
- if rowcount == 0:
- raise StoreError(404, "No row found (%s)" % (table,))
- if rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- @staticmethod
- def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
- select_sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(select_sql, list(keyvalues.values()))
- row = txn.fetchone()
-
- if not row:
- if allow_none:
- return None
- raise StoreError(404, "No row found (%s)" % (table,))
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- return dict(zip(retcols, row))
-
- def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
- """Executes a DELETE query on the named table, expecting to delete a
- single row.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- """
- return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
-
- @staticmethod
- def simple_delete_one_txn(txn, table, keyvalues):
- """Executes a DELETE query on the named table, expecting to delete a
- single row.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- """
- sql = "DELETE FROM %s WHERE %s" % (
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(sql, list(keyvalues.values()))
- if txn.rowcount == 0:
- raise StoreError(404, "No row found (%s)" % (table,))
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- def simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
-
- @staticmethod
- def simple_delete_txn(txn, table, keyvalues):
- sql = "DELETE FROM %s WHERE %s" % (
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(sql, list(keyvalues.values()))
- return txn.rowcount
-
- def simple_delete_many(self, table, column, iterable, keyvalues, desc):
- return self.runInteraction(
- desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
- )
-
- @staticmethod
- def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
- """Executes a DELETE query on the named table.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
- Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
-
- Returns:
-            int: Number of rows deleted
- """
- if not iterable:
- return 0
-
- sql = "DELETE FROM %s" % table
-
- clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
- clauses = [clause]
-
- for key, value in iteritems(keyvalues):
- clauses.append("%s = ?" % (key,))
- values.append(value)
-
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
- txn.execute(sql, values)
-
- return txn.rowcount
-
- def get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
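-        # e.g. for the presence_stream table this returns something like
-        # ({"@alice:hs": 612345, "@bob:hs": 612340}, 612340): each entity's
-        # maximum stream position, plus the smallest of those maxima (the
-        # point the cache is complete from). Stream ids here are made up.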
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
-
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, (int(max_value),))
-
- cache = {row[0]: int(row[1]) for row in txn}
-
- txn.close()
-
- if cache:
- min_val = min(itervalues(cache))
- else:
- min_val = max_value
-
- return cache, min_val
-
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -1347,159 +71,6 @@ class SQLBaseStore(object):
# which is fine.
pass
- def simple_select_list_paginate(
- self,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction="ASC",
- desc="simple_select_list_paginate",
- ):
- """
-        Executes a SELECT query on the named table, returning up to `limit`
-        rows starting at row number `start`, as a list of dicts.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc,
- self.simple_select_list_paginate_txn,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction=order_direction,
- )
-
- @classmethod
- def simple_select_list_paginate_txn(
- cls,
- txn,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction="ASC",
- ):
- """
-        Executes a SELECT query on the named table, returning up to `limit`
-        rows starting at row number `start`, as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            list[dict[str, Any]]: the selected rows
- """
- if order_direction not in ["ASC", "DESC"]:
- raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
-
- if keyvalues:
- where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
- else:
- where_clause = ""
-
- sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
- ", ".join(retcols),
- table,
- where_clause,
- orderby,
- order_direction,
- )
- txn.execute(sql, list(keyvalues.values()) + [limit, start])
-
- return cls.cursor_to_dict(txn)
-
- def get_user_count_txn(self, txn):
- """Get a total number of registered users in the users list.
-
- Args:
- txn : Transaction object
- Returns:
- int : number of users
- """
- sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
- txn.execute(sql_count)
- return txn.fetchone()[0]
-
- def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table (str): the table name
-            term (str | None):
-                term to search for, matched against the given column.
-            col (str): the column the term should be matched against
- retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]], or 0 if no
-                term is given
- """
-
- return self.runInteraction(
- desc, self.simple_search_list_txn, table, term, col, retcols
- )
-
- @classmethod
- def simple_search_list_txn(cls, txn, table, term, col, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
-            term (str | None):
-                term to search for, matched against the given column.
-            col (str): the column the term should be matched against
- retcols (iterable[str]): the names of the columns to return
-        Returns:
-            list[dict[str, Any]] if a term is given, otherwise 0
- """
- if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
- termvalues = ["%%" + term + "%%"]
- txn.execute(sql, termvalues)
- else:
- return 0
-
- return cls.cursor_to_dict(txn)
-
-
-class _RollbackButIsFineException(Exception):
- """ This exception is used to rollback a transaction without implying
- something went wrong.
- """
-
- pass
-
def db_to_json(db_content):
"""
@@ -1528,30 +99,3 @@ def db_to_json(db_content):
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
-
-
-def make_in_list_sql_clause(
- database_engine, column: str, iterable: Iterable
-) -> Tuple[str, Iterable]:
- """Returns an SQL clause that checks the given column is in the iterable.
-
-    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
-    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
-    using the `ANY` form on postgres means that postgres treats queries with
-    different-length iterables as the same query, helping the query stats.
-
- Args:
- database_engine
- column: Name of the column
- iterable: The values to check the column against.
-
- Returns:
- A tuple of SQL query and the args
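-
-        e.g. ("user_id = ANY(?)", [["@a:hs", "@b:hs"]]) on postgres, or
-        ("user_id IN (?,?)", ["@a:hs", "@b:hs"]) on sqlite.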
- """
-
- if database_engine.supports_using_any_list:
- # This should hopefully be faster, but also makes postgres query
- # stats easier to understand.
- return "%s = ANY(?)" % (column,), [list(iterable)]
- else:
- return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 06955a0537..dfca94b0e0 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore):
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = yield self.simple_select_onecol(
+ updates = yield self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore):
if update_name in self._background_update_queue:
return False
- update_exists = await self.simple_select_one_onecol(
+ update_exists = await self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="1",
@@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
no more work to do.
"""
if not self._background_update_queue:
- updates = yield self.simple_select_list(
+ updates = yield self.db.simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name", "depends_on"),
@@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
- progress_json = yield self.simple_select_one_onecol(
+ progress_json = yield self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@@ -391,7 +391,7 @@ class BackgroundUpdateStore(SQLBaseStore):
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.runWithConnection(runner)
+ yield self.db.runWithConnection(runner)
yield self._end_background_update(update_name)
return 1
@@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = []
progress_json = json.dumps(progress)
- return self.simple_insert(
+ return self.db.simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json},
)
@@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
- return self.simple_delete_one(
+ return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
@@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = json.dumps(progress)
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 2a5b33dda1..46f0f26af6 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -171,9 +171,11 @@ class DataStore(
else:
self._cache_id_gen = None
+ super(DataStore, self).__init__(db_conn, hs)
+
self._presence_on_startup = self._get_active_presence(db_conn)
- presence_cache_prefill, min_presence_val = self.get_cache_dict(
+ presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
@@ -187,7 +189,7 @@ class DataStore(
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self.get_cache_dict(
+ device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
@@ -202,7 +204,7 @@ class DataStore(
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self.get_cache_dict(
+ device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
@@ -228,7 +230,7 @@ class DataStore(
)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
@@ -242,7 +244,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
- _group_updates_prefill, min_group_updates_id = self.get_cache_dict(
+ _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
@@ -262,8 +264,6 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
- super(DataStore, self).__init__(db_conn, hs)
-
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@@ -283,7 +283,7 @@ class DataStore(
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
txn.close()
for row in rows:
@@ -296,7 +296,7 @@ class DataStore(
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return self.runInteraction("count_daily_users", self._count_users, yesterday)
+ return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
def count_monthly_users(self):
"""
@@ -306,7 +306,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return self.runInteraction(
+ return self.db.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@@ -406,7 +406,7 @@ class DataStore(
return results
- return self.runInteraction("count_r30_users", _count_r30_users)
+ return self.db.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@@ -471,7 +471,7 @@ class DataStore(
# frequently
self._last_user_visit_update = now
- return self.runInteraction(
+ return self.db.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
@@ -482,7 +482,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- return self.simple_select_list(
+ return self.db.simple_select_list(
table="users",
keyvalues={},
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
@@ -502,9 +502,9 @@ class DataStore(
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
- users = yield self.runInteraction(
+ users = yield self.db.runInteraction(
"get_users_paginate",
- self.simple_select_list_paginate_txn,
+ self.db.simple_select_list_paginate_txn,
table="users",
keyvalues={"is_guest": False},
orderby=order,
@@ -512,7 +512,9 @@ class DataStore(
limit=limit,
retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
)
- count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
+ count = yield self.db.runInteraction(
+ "get_users_paginate", self.get_user_count_txn
+ )
retval = {"users": users, "total": count}
return retval
@@ -526,7 +528,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- return self.simple_search_list(
+ return self.db.simple_search_list(
table="users",
term=term,
col="name",
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index b0d22faf3f..a96fe9485c 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
- rows = self.simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- rows = self.simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@@ -92,7 +92,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
Deferred: A dict
"""
- result = yield self.simple_select_one_onecol(
+ result = yield self.db.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
- rows = self.simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@@ -138,7 +138,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
- content_json = self.simple_select_one_onecol_txn(
+ content_json = self.db.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return json.loads(content_json) if content_json else None
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -207,7 +207,7 @@ class AccountDataWorkerStore(SQLBaseStore):
room_results = txn.fetchall()
return global_results, room_results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
@@ -252,7 +252,7 @@ class AccountDataWorkerStore(SQLBaseStore):
if not changed:
return {}, {}
- return self.runInteraction(
+ return self.db.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -302,7 +302,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
- yield self.simple_upsert(
+ yield self.db.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@@ -348,7 +348,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
- yield self.simple_upsert(
+ yield self.db.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -388,4 +388,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.runInteraction("update_account_data_max_stream_id", _update)
+ return self.db.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 6b82fd392a..6b2e12719c 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
- results = yield self.simple_select_list(
+ results = yield self.db.simple_select_list(
"application_services_state", dict(state=state), ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
@@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
- result = yield self.simple_select_one(
+ result = yield self.db.simple_select_one(
"application_services_state",
dict(as_id=service.id),
["state"],
@@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
- return self.simple_upsert(
+ return self.db.simple_upsert(
"application_services_state", dict(as_id=service.id), dict(state=state)
)
@@ -216,7 +216,7 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.runInteraction("create_appservice_txn", _create_appservice_txn)
+ return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
"application_services_state",
dict(as_id=service.id),
@@ -257,11 +257,13 @@ class ApplicationServiceTransactionWorkerStore(
)
# Delete txn
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
)
- return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
+ return self.db.runInteraction(
+ "complete_appservice_txn", _complete_appservice_txn
+ )
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@@ -283,7 +285,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return None
@@ -291,7 +293,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry
- entry = yield self.runInteraction(
+ entry = yield self.db.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
@@ -321,7 +323,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@@ -350,7 +352,7 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.runInteraction(
+ upper_bound, event_ids = yield self.db.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
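The appservice hunks show the same move for the simple_* helpers, including their *_txn variants that operate on an already-open transaction. A sketch of the shape of the completion path above, trimmed to the two migrated calls; the upsert's values dict is elided in the hunk, so `last_txn` is an assumed column name:

    def _complete_appservice_txn(txn):
        # Record the AS's latest completed transaction id...
        self.db.simple_upsert_txn(
            txn,
            "application_services_state",
            dict(as_id=service.id),
            dict(last_txn=txn_id),  # assumed values; elided in the hunk
        )
        # ...then delete the completed transaction itself.
        self.db.simple_delete_txn(
            txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
        )

    return self.db.runInteraction("complete_appservice_txn", _complete_appservice_txn)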
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index de3256049d..54ed8574c4 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -95,7 +95,7 @@ class CacheInvalidationStore(SQLBaseStore):
txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="cache_invalidation_stream",
values={
@@ -122,7 +122,9 @@ class CacheInvalidationStore(SQLBaseStore):
txn.execute(sql, (last_id, limit))
return txn.fetchall()
- return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
def get_cache_stream_token(self):
if self._cache_id_gen:
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 66522a04b7..6f2a720b97 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -91,7 +91,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
- yield self.runWithConnection(f)
+ yield self.db.runWithConnection(f)
yield self._end_background_update("user_ips_drop_nonunique_index")
return 1
@@ -106,7 +106,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.runInteraction("user_ips_analyze", user_ips_analyze)
+ yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
yield self._end_background_update("user_ips_analyze")
@@ -140,7 +140,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
- end_last_seen = yield self.runInteraction(
+ end_last_seen = yield self.db.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@@ -275,7 +275,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.runInteraction("user_ips_dups_remove", remove)
+ yield self.db.runInteraction("user_ips_dups_remove", remove)
if last:
yield self._end_background_update("user_ips_remove_dupes")
@@ -352,7 +352,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
return len(rows)
- updated = yield self.runInteraction(
+ updated = yield self.db.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
@@ -417,12 +417,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.runInteraction(
+ return self.db.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
- if "user_ips" in self._unsafe_to_upsert_tables or (
+ if "user_ips" in self.db._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
@@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
- res = yield self.simple_select_list(
+ res = yield self.db.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self.simple_select_list(
+ rows = yield self.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -577,4 +577,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))
- await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
+ await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
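One behavioural detail in this file deserves a note: the set of tables that are not yet safe for a native upsert travels with the Database object, while the engine handle stays on the store. A sketch of the guard as it reads in _update_client_ips_batch_txn above:

    # If the engine has no native UPSERT, or the unique index that makes an
    # upsert safe has not yet been built by a background update, lock the
    # table so the emulated upsert cannot race with a concurrent writer.
    if "user_ips" in self.db._unsafe_to_upsert_tables or (
        not self.database_engine.can_native_upsert
    ):
        self.database_engine.lock_table(txn, "user_ips")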
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 206d39134d..440793ad49 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
- return self.runInteraction(
+ return self.db.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- count = yield self.runInteraction(
+ count = yield self.db.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
- return self.runInteraction(
+ return self.db.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@@ -203,7 +203,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
@@ -232,7 +232,7 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
- yield self.runWithConnection(reindex_txn)
+ yield self.db.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
@@ -294,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
- already_inserted = self.simple_select_one_txn(
+ already_inserted = self.db.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add an entry for this message_id so that we know we've processed
# it.
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@@ -344,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@@ -465,6 +465,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
return rows
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
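The remote-inbox hunks above carry a small dedup protocol: before storing messages from a federation origin, the store checks whether that (origin, message_id) pair was already recorded, since an origin that missed our acknowledgement may resend. A sketch of that check with the migrated helpers; the retcols, the allow_none flag, and the full values dict are elided in the hunks, so they are assumptions here:

    already_inserted = self.db.simple_select_one_txn(
        txn,
        table="device_federation_inbox",
        keyvalues={"origin": origin, "message_id": message_id},
        retcols=("message_id",),  # assumed; elided in the hunk
        allow_none=True,          # assumed, so a miss returns None
    )
    if already_inserted is not None:
        return  # duplicate delivery; nothing to do

    # Record the message_id so a replay of this delivery is skipped.
    self.db.simple_insert_txn(
        txn,
        table="device_federation_inbox",
        values={"origin": origin, "message_id": message_id},  # partially elided
    )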
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 727c582121..d98511ddd4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore):
Raises:
StoreError: if the device is not found
"""
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each
device.
"""
- devices = yield self.simple_select_list(
+ devices = yield self.db.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore):
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
- updates = yield self.runInteraction(
+ updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
@@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
devices = (
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
@@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.runInteraction("get_last_device_update_for_remote_user", f)
+ return self.db.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
from_user_id,
@@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore):
from_user_id,
stream_id,
)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"user_signature_stream",
values={
@@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
- content = yield self.simple_select_one_onecol(
+ content = yield self.db.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
- devices = yield self.simple_select_list(
+ devices = yield self.db.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
(stream_id, devices)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
@@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore):
return changes
- return self.runInteraction(
+ return self.db.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
@@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
- rows = yield self.execute(
+ rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return set(user for row in rows for user in json.loads(row[0]))
@@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
- return self.execute(
+ return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -685,7 +685,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
- yield self.runWithConnection(f)
+ yield self.db.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1
@@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = yield self.simple_insert(
+ inserted = yield self.db.simple_insert(
"devices",
values={
"user_id": user_id,
@@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
- hidden = yield self.simple_select_one_onecol(
+ hidden = yield self.db.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self.simple_delete_one(
+ yield self.db.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self.simple_delete_many(
+ yield self.db.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
@@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- yield self.simple_delete(
+ yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
@@ -853,7 +853,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -914,7 +914,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
@@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -962,7 +962,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
user_id,
@@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, stream_id) for device_id in device_ids],
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@@ -1069,7 +1069,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
- self.runInteraction,
+ self.db.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
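A recurring pairing in the device hunks: cache invalidation is queued with txn.call_after, so it fires only if the transaction commits, while the write itself goes through the Database helpers. A sketch of the pairing around the device_lists_remote_extremeties upsert above; the values dict is cut off in the hunk, so the stream_id column is inferred from the surrounding code:

    # Invalidate the per-user stream-id cache only once the write commits.
    txn.call_after(
        self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
    )
    self.db.simple_upsert_txn(
        txn,
        table="device_lists_remote_extremeties",
        keyvalues={"user_id": user_id},
        values={"stream_id": stream_id},  # inferred; elided in the hunk
    )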
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index d332f8a409..c9e7de7d12 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: resolves to a namedtuple with keys "room_id" and
"servers", or None if no association can be found
"""
- room_id = yield self.simple_select_one_onecol(
+ room_id = yield self.db.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
- servers = yield self.simple_select_onecol(
+ servers = yield self.db.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
"""
def alias_txn(txn):
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"room_aliases",
{
@@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.runInteraction("create_room_alias_association", alias_txn)
+ ret = yield self.db.runInteraction(
+ "create_room_alias_association", alias_txn
+ )
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
@@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
- room_id = yield self.runInteraction(
+ room_id = yield self.db.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
@@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index df89eda337..84594cf0a9 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -38,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self.simple_update_one(
+ yield self.db.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -89,7 +89,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}
)
- yield self.simple_insert_many(
+ yield self.db.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@@ -125,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = yield self.simple_select_list(
+ rows = yield self.db.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -170,7 +170,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room keys
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@@ -234,7 +234,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version (str): the version ID of the backup we're querying about
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@@ -267,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- yield self.simple_delete(
+ yield self.db.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -312,7 +312,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
- result = self.simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
@@ -324,7 +324,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0
return result
- return self.runInteraction(
+ return self.db.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@@ -352,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
@@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
- return self.runInteraction(
+ return self.db.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@@ -391,7 +391,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag
if updatevalues:
- return self.simple_update(
+ return self.db.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
@@ -420,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
else:
this_version = version
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
- return self.simple_update_one_txn(
+ return self.db.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
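The version-deletion hunk illustrates mixing the helpers inside one transaction: the backed-up keys are deleted outright, while the version row is only soft-deleted. A sketch using the calls as they appear above:

    def _delete_e2e_room_keys_version_txn(txn):
        # Hard-delete the backed-up room keys for this version...
        self.db.simple_delete_txn(
            txn,
            table="e2e_room_keys",
            keyvalues={"user_id": user_id, "version": this_version},
        )
        # ...but keep the version row itself, marking it deleted rather
        # than dropping it.
        return self.db.simple_update_one_txn(
            txn,
            table="e2e_room_keys_versions",
            keyvalues={"user_id": user_id, "version": this_version},
            updatevalues={"deleted": 1},
        )

    return self.db.runInteraction(
        "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
    )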
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 08bcdc4725..38cd0ca9b8 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = yield self.runInteraction(
+ results = yield self.db.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
result = {}
for row in rows:
@@ -143,7 +143,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(signature_sql, signature_query_params)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
@@ -186,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
key_id) to json string for key
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@@ -219,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
values=[
@@ -238,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@@ -261,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
+ return self.db.runInteraction(
+ "count_e2e_one_time_keys", _count_e2e_one_time_keys
+ )
def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@@ -322,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_e2e_cross_signing_key",
self._get_e2e_cross_signing_key_txn,
user_id,
@@ -350,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
"""
- return self.execute(
+ return self.db.execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
)
@@ -367,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
- old_key_json = self.simple_select_one_onecol_txn(
+ old_key_json = self.db.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -383,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"Message": "Device key already stored."})
return False
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -392,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+ return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
@@ -431,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
+ return self.db.runInteraction(
+ "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+ )
def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn):
@@ -442,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"user_id": user_id,
}
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -456,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
@@ -492,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# The "keys" property must only have one entry, which will be the public
# key, so we just grab the first value in there
pubkey = next(iter(key["keys"].values()))
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"devices",
values={
@@ -505,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
@@ -524,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
user_id,
@@ -539,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add
"""
- return self.simple_insert_many(
+ return self.db.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 051ac7a8cb..77e4353b59 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -58,7 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
list of event_ids
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
)
@@ -90,12 +90,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
def get_oldest_events_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[int]
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.simple_select_onecol_txn(
+ return self.db.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
@@ -188,7 +188,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
where *hashes* is a map from algorithm to hash.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room,
room_id,
@@ -229,13 +229,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
@@ -266,12 +266,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
- min_depth = self.simple_select_one_onecol_txn(
+ min_depth = self.db.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -337,7 +337,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
@@ -352,7 +352,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int)
"""
return (
- self.runInteraction(
+ self.db.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
@@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
- depth = self.simple_select_one_onecol_txn(
+ depth = self.db.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -415,7 +415,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
- ids = yield self.runInteraction(
+ ids = yield self.db.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[list[str]]
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore):
if min_depth and depth >= min_depth:
return
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_edges",
values=[
@@ -604,13 +604,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
- self.runInteraction,
+ self.db.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@@ -660,7 +660,7 @@ class EventFederationStore(EventFederationWorkerStore):
return min_stream_id >= target_min_stream_id
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
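The depth lookup near the top of this file also shows the batched-select helper after the move. A sketch of the call shape; the retcol is implied by the max() over row["depth"], since the hunk is cut short before the remaining keyword arguments:

    rows = yield self.db.simple_select_many_batch(
        table="events",
        column="event_id",   # match rows whose event_id is in the iterable
        iterable=event_ids,
        retcols=("depth",),  # implied by the max() below; elided in the hunk
    )
    return max(row["depth"] for row in rows)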
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 0a37847cfd..725d0881dc 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -93,7 +93,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.runInteraction(
+ ret = yield self.db.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
@@ -177,7 +177,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = yield self.runInteraction("get_push_action_users_in_range", f)
+ ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
return ret
@defer.inlineCallbacks
@@ -229,7 +229,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.runInteraction(
+ after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -257,7 +257,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.runInteraction(
+ no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -329,7 +329,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.runInteraction(
+ after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -357,7 +357,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.runInteraction(
+ no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -407,7 +407,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.runInteraction(
+ return self.db.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
@@ -458,7 +458,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = yield self.simple_delete(
+ res = yield self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@@ -489,7 +489,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
- self.runInteraction,
+ self.db.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@@ -525,7 +525,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
for event, _ in events_and_contexts:
- user_ids = self.simple_select_onecol_txn(
+ user_ids = self.db.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@@ -727,9 +727,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- push_actions = yield self.runInteraction("get_push_actions_for_user", f)
+ push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
@@ -748,7 +748,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
- result = yield self.runInteraction("get_time_of_last_push_action_before", f)
+ result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@defer.inlineCallbacks
@@ -757,7 +757,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
+ result = yield self.db.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
@@ -830,7 +832,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
- caught_up = yield self.runInteraction(
+ caught_up = yield self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
@@ -844,7 +846,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
- old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -880,7 +882,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
- old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -912,7 +914,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
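One more wrinkle shown in this file: where a store method runs as a background process, the bound self.db.runInteraction is handed to the wrapper as the callable itself. A sketch of that indirection, assuming run_as_background_process(desc, func, *args) invokes func(*args) under its own logging context:

    def _find_stream_orderings_for_times(self):
        return run_as_background_process(
            "event_push_action_stream_orderings",       # metrics name
            self.db.runInteraction,                     # callable to run
            "_find_stream_orderings_for_times",         # transaction description
            self._find_stream_orderings_for_times_txn,  # transaction function
        )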
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 98ae69e996..01ec9ec397 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -143,7 +143,7 @@ class EventsStore(
)
return txn.fetchall()
- res = yield self.runInteraction("read_forward_extremities", fetch)
+ res = yield self.db.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
@_retry_on_integrity_error
@@ -208,7 +208,7 @@ class EventsStore(
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.runInteraction(
+ yield self.db.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -281,7 +281,7 @@ class EventsStore(
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
@@ -345,7 +345,7 @@ class EventsStore(
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
@@ -432,7 +432,7 @@ class EventsStore(
# event's auth chain, but it's easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -580,12 +580,12 @@ class EventsStore(
self, txn, new_forward_extremities, max_stream_order
):
for room_id, new_extrem in iteritems(new_forward_extremities):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
@@ -598,7 +598,7 @@ class EventsStore(
# new stream_ordering to new forward extremities in the room.
# This allows us to later efficiently look up the forward extremities
# for a room before a given stream_ordering
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@@ -722,7 +722,7 @@ class EventsStore(
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@@ -794,7 +794,7 @@ class EventsStore(
d.pop("redacted_because", None)
return d
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_json",
values=[
@@ -811,7 +811,7 @@ class EventsStore(
],
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="events",
values=[
@@ -841,7 +841,7 @@ class EventsStore(
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
- self.simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="redactions",
keyvalues={"redacts": event.event_id},
@@ -983,7 +983,7 @@ class EventsStore(
state_values.append(vals)
- self.simple_insert_many_txn(txn, table="state_events", values=state_values)
+ self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1014,7 +1014,7 @@ class EventsStore(
)
txn.execute(sql + clause, args)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
@@ -1032,7 +1032,7 @@ class EventsStore(
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="redactions",
values={
@@ -1077,7 +1077,9 @@ class EventsStore(
LIMIT ?
"""
- rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100)
+ rows = yield self.db.execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
updates = []
@@ -1109,14 +1111,14 @@ class EventsStore(
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
table="redactions",
keyvalues={"event_id": redaction_id},
updatevalues={"have_censored": True},
)
- yield self.runInteraction("_update_censor_txn", _update_censor_txn)
+ yield self.db.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
@@ -1127,7 +1129,7 @@ class EventsStore(
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
"""
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
@@ -1153,7 +1155,7 @@ class EventsStore(
(count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_messages", _count_messages)
+ ret = yield self.db.runInteraction("count_messages", _count_messages)
return ret
@defer.inlineCallbacks
@@ -1174,7 +1176,7 @@ class EventsStore(
(count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
+ ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
return ret
@defer.inlineCallbacks
@@ -1189,7 +1191,7 @@ class EventsStore(
(count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_daily_active_rooms", _count)
+ ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
def get_current_backfill_token(self):
@@ -1241,7 +1243,7 @@ class EventsStore(
return new_event_updates
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
@@ -1286,7 +1288,7 @@ class EventsStore(
return new_event_updates
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@@ -1379,7 +1381,7 @@ class EventsStore(
backward_ex_outliers,
)
- return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
def purge_history(self, room_id, token, delete_local_events):
"""Deletes room history before a certain point
@@ -1399,7 +1401,7 @@ class EventsStore(
deleted events.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -1647,7 +1649,7 @@ class EventsStore(
Deferred[List[int]]: The list of state groups to delete.
"""
- return self.runInteraction("purge_room", self._purge_room_txn, room_id)
+ return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
@@ -1766,7 +1768,7 @@ class EventsStore(
to delete.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@@ -1778,7 +1780,7 @@ class EventsStore(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self.simple_select_many_txn(
+ rows = self.db.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
@@ -1805,15 +1807,15 @@ class EventsStore(
curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
curr_state = curr_state[sg]
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="state_group_edges", keyvalues={"state_group": sg}
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -1850,7 +1852,7 @@ class EventsStore(
state group.
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
@@ -1869,7 +1871,7 @@ class EventsStore(
state_groups_to_delete (list[int]): State groups to delete
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
@@ -1880,7 +1882,7 @@ class EventsStore(
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
- self.simple_delete_many_txn(
+ self.db.simple_delete_many_txn(
txn,
table="state_groups_state",
column="state_group",
@@ -1891,7 +1893,7 @@ class EventsStore(
# ... and the state group edges
logger.info("[purge] removing %s from state_group_edges", room_id)
- self.simple_delete_many_txn(
+ self.db.simple_delete_many_txn(
txn,
table="state_group_edges",
column="state_group",
@@ -1902,7 +1904,7 @@ class EventsStore(
# ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
- self.simple_delete_many_txn(
+ self.db.simple_delete_many_txn(
txn,
table="state_groups",
column="id",
@@ -1919,7 +1921,7 @@ class EventsStore(
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
- res = yield self.simple_select_one(
+ res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1942,7 +1944,7 @@ class EventsStore(
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1960,7 +1962,7 @@ class EventsStore(
room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology.
"""
- return self.simple_insert_many_txn(
+ return self.db.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -1982,7 +1984,7 @@ class EventsStore(
event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
- return self.simple_insert_txn(
+ return self.db.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
@@ -2031,7 +2033,7 @@ class EventsStore(
txn, "_get_event_cache", (event.event_id,)
)
- yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
+ yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the
@@ -2041,7 +2043,7 @@ class EventsStore(
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
"""
- return self.simple_delete_txn(
+ return self.db.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
@@ -2065,7 +2067,7 @@ class EventsStore(
return txn.fetchone()
- return self.runInteraction(
+ return self.db.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 37dfc8c871..365e966956 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -151,7 +151,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
@@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self.simple_select_many_txn(
+ ev_rows = self.db.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@@ -228,7 +228,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(rows_to_update)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
@@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
to_delete.intersection_update(original_set)
- deleted = self.simple_delete_many_txn(
+ deleted = self.db.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self.simple_select_many_txn(
+ rows = self.db.simple_select_many_txn(
txn,
table="events",
column="event_id",
@@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self.simple_delete_many_txn(
+ self.db.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@@ -406,7 +406,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(original_set)
- num_handled = yield self.runInteraction(
+ num_handled = yield self.db.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
@@ -416,7 +416,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
@@ -470,7 +470,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(rows)
- count = yield self.runInteraction(
+ count = yield self.db.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
@@ -501,7 +501,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
@@ -533,7 +533,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
try:
event_json = json.loads(event_json_raw)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -565,7 +565,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return nbrows
- num_rows = yield self.runInteraction(
+ num_rows = yield self.db.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 6a08a746b6..e041fc5eac 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -78,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore):
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
@@ -117,7 +117,7 @@ class EventsWorkerStore(SQLBaseStore):
return ts
- return self.runInteraction(
+ return self.db.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@@ -452,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore):
event_id for events, _ in event_list for event_id in events
)
- row_dict = self.new_transaction(
+ row_dict = self.db.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
@@ -584,7 +584,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
- "fetch_events", self.runWithConnection, self._do_fetch
+ "fetch_events", self.db.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events: %s", len(events), events)
@@ -745,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -780,7 +780,9 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
+ yield self.db.runInteraction(
+ "have_seen_events", have_seen_events_txn, chunk
+ )
return results
def _get_total_state_event_counts_txn(self, txn, room_id):
@@ -807,7 +809,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
@@ -832,7 +834,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
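
The `have_seen_events` hunk above walks its input in fixed-size chunks using the two-argument (sentinel) form of `iter()`, issuing one `runInteraction` per chunk. A standalone sketch of just that idiom, with placeholder data:

    import itertools

    event_ids = ["$event%d:example.org" % i for i in range(250)]
    input_iterator = iter(event_ids)

    # iter(callable, sentinel) keeps calling the lambda until it returns the
    # sentinel ([]), i.e. until islice finds the iterator exhausted.
    for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
        print(len(chunk))  # 100, 100, 50
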
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py
index 17ef7b9354..342d6622a4 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = yield self.simple_select_one_onecol(
+ def_json = yield self.db.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.runInteraction("add_user_filter", _do_txn)
+ return self.db.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 9e1d12bcb7..7f5e8dce66 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore):
* "invite"
* "open"
"""
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
@@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore):
)
def get_group(self, group_id):
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self.simple_select_list(
+ return self.db.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
@@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id):
# TODO: Pagination
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self.simple_select_list(
+ return self.db.simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
retcols=("room_id", "is_public"),
@@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore):
return rooms, categories
- return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
+ return self.db.runInteraction(
+ "get_rooms_for_summary", _get_rooms_for_summary_txn
+ )
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
- return self.runInteraction(
+ return self.db.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
"""
- room_in_group = self.simple_select_one_onecol_txn(
+ room_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
- cat_exists = self.simple_select_one_onecol_txn(
+ cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
- cat_exists = self.simple_select_one_onecol_txn(
+ cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, category_id, group_id, category_id),
)
- existing = self.simple_select_one_txn(
+ existing = self.db.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self.simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self.simple_delete(
+ return self.db.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_categories(self, group_id):
- rows = yield self.simple_select_list(
+ rows = yield self.db.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
@@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
- category = yield self.simple_select_one(
+ category = yield self.db.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
@@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_category(self, group_id, category_id):
- return self.simple_delete(
+ return self.db.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
@@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_roles(self, group_id):
- rows = yield self.simple_select_list(
+ rows = yield self.db.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
@@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
- role = yield self.simple_select_one(
+ role = yield self.db.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
@@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_role(self, group_id, role_id):
- return self.simple_delete(
+ return self.db.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
- return self.runInteraction(
+ return self.db.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
"""
- user_in_group = self.simple_select_one_onecol_txn(
+ user_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
- role_exists = self.simple_select_one_onecol_txn(
+ role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
- role_exists = self.simple_select_one_onecol_txn(
+ role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, role_id, group_id, role_id),
)
- existing = self.simple_select_one_txn(
+ existing = self.db.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self.simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_summary_users",
values={
@@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self.simple_delete(
+ return self.db.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
@@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore):
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room
"""
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
@@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore):
return users, roles
- return self.runInteraction(
+ return self.db.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
def is_user_in_group(self, user_id, group_id):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore):
).addCallback(lambda r: bool(r))
def is_user_admin_in_group(self, group_id, user_id):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore):
def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user
"""
- return self.simple_insert(
+ return self.db.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
@@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_invited_to_local_group(self, group_id, user_id):
"""Has the group server invited a user?
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _get_users_membership_in_group_txn(txn):
- row = self.simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore):
"is_privileged": row["is_admin"],
}
- row = self.simple_select_one_onecol_txn(
+ row = self.db.simple_select_one_onecol_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore):
return {}
- return self.runInteraction(
+ return self.db.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
@@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _add_user_to_group_txn(txn):
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_users",
values={
@@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore):
},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore):
},
)
- return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
+ return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
def add_room_to_group(self, group_id, room_id, is_public):
- return self.simple_insert(
+ return self.db.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self.simple_update(
+ return self.db.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore):
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
def get_publicised_groups_for_user(self, user_id):
"""Get all groups a user is publicising
"""
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
@@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore):
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore):
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_group_membership",
values={
@@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore):
},
)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_group_updates",
values={
@@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore):
if membership == "join":
if local_attestation:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore):
},
)
else:
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore):
return next_id
with self._group_updates_id_gen.get_next() as next_id:
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
@@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore):
def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
- yield self.simple_insert(
+ yield self.db.simple_insert(
table="groups",
values={
"group_id": group_id,
@@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
- yield self.simple_update_one(
+ yield self.db.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
@@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore):
WHERE valid_until_ms <= ?
"""
txn.execute(sql, (valid_until_ms,))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed
"""
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore):
def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed
"""
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
@@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore):
group_id (str)
user_id (str)
"""
- return self.simple_delete(
+ return self.db.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
@@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
- row = yield self.simple_select_one(
+ row = yield self.db.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
@@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore):
return None
def get_joined_groups(self, user_id):
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
@@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore):
for row in txn
]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
@@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore):
for group_id, membership, gtype, content_json in txn
]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
@@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore):
for stream_id, group_id, user_id, gtype, content_json in txn
]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
@@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore):
]
for table in tables:
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table=table, keyvalues={"group_id": group_id}
)
- return self.runInteraction("delete_group", _delete_group_txn)
+ return self.db.runInteraction("delete_group", _delete_group_txn)
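
`_add_room_to_summary_txn` and `_add_user_to_summary_txn` above share one shape: read any existing summary row with `simple_select_one_txn`, then either `simple_update_txn` or `simple_insert_txn`. A self-contained sqlite sketch of that shape; the table and column names come from the hunk, the rest is illustrative only:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    txn = conn.cursor()
    txn.execute(
        "CREATE TABLE group_summary_rooms (group_id TEXT, room_id TEXT, room_order BIGINT)"
    )


    def add_room_to_summary_txn(txn, group_id, room_id, order):
        # check whether the room already has a summary row
        txn.execute(
            "SELECT room_order FROM group_summary_rooms WHERE group_id = ? AND room_id = ?",
            (group_id, room_id),
        )
        existing = txn.fetchone()
        if existing is not None:
            txn.execute(
                "UPDATE group_summary_rooms SET room_order = ?"
                " WHERE group_id = ? AND room_id = ?",
                (order, group_id, room_id),
            )
        else:
            txn.execute(
                "INSERT INTO group_summary_rooms (group_id, room_id, room_order)"
                " VALUES (?, ?, ?)",
                (group_id, room_id, order),
            )


    add_room_to_summary_txn(txn, "+group:example.org", "!room:example.org", 1)
    add_room_to_summary_txn(txn, "+group:example.org", "!room:example.org", 2)
    txn.execute("SELECT * FROM group_summary_rooms")
    print(txn.fetchall())  # one row, room_order == 2
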
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index c7150432b3..6b12f5a75f 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
- return self.runInteraction("get_server_verify_keys", _txn)
+ return self.db.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
@@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore):
f((i,))
return res
- return self.runInteraction(
+ return self.db.runInteraction(
"store_server_verify_keys",
- self.simple_upsert_many_txn,
+ self.db.simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore):
ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON.
"""
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore):
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
- rows = self.simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
@@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
- return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
+ return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
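
`store_server_verify_keys` above passes `self.db.simple_upsert_many_txn` straight to `runInteraction` instead of wrapping it, which works because `runInteraction` forwards its remaining arguments to the function it runs. The stand-ins below are assumptions made purely for illustration, including the value names:

    def run_interaction(desc, func, *args, **kwargs):
        txn = object()  # stands in for the cursor runInteraction would supply
        return func(txn, *args, **kwargs)


    def simple_upsert_many_txn(txn, table, key_names, key_values, value_names, value_values):
        return "%d upsert(s) into %s" % (len(key_values), table)


    print(
        run_interaction(
            "store_server_verify_keys",
            simple_upsert_many_txn,
            table="server_signature_keys",
            key_names=("server_name", "key_id"),
            key_values=[("example.org", "ed25519:abc")],
            value_names=("verify_key",),
            value_values=[(b"key-bytes",)],
        )
    )
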
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 0cb9446f96..ea02497784 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
Returns:
None if the media_id doesn't exist.
"""
- return self.simple_select_one(
+ return self.db.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id,
url_cache=None,
):
- return self.simple_insert(
+ return self.db.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -124,12 +124,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.runInteraction("get_url_cache", get_url_cache_txn)
+ return self.db.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self.simple_insert(
+ return self.db.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
- return self.simple_select_list(
+ return self.db.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.simple_insert(
+ return self.db.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
- return self.simple_select_one(
+ return self.db.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self.simple_insert(
+ return self.db.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -250,10 +250,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+ return self.db.runInteraction(
+ "update_cached_last_access_time", update_cache_txn
+ )
def get_remote_media_thumbnails(self, origin, media_id):
- return self.simple_select_list(
+ return self.db.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -278,7 +280,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.simple_insert(
+ return self.db.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -300,24 +302,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
- return self.execute(
- "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ return self.db.execute(
+ "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+ return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
@@ -331,7 +333,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+ return self.db.runInteraction(
+ "get_expired_url_cache", _get_expired_url_cache_txn
+ )
def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
@@ -342,7 +346,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@@ -356,7 +360,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
@@ -373,6 +377,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
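
Several hunks above route row conversion through `self.db.cursor_to_dict`; functionally it zips the column names from `cursor.description` with each fetched row. An equivalent, self-contained sketch of that behaviour:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    txn = conn.cursor()
    txn.execute(
        "CREATE TABLE remote_media_cache (media_origin TEXT, media_id TEXT, media_length BIGINT)"
    )
    txn.execute("INSERT INTO remote_media_cache VALUES ('example.org', 'abcdef', 1024)")

    txn.execute("SELECT * FROM remote_media_cache WHERE media_length > ?", (0,))
    col_headers = [col[0] for col in txn.description]
    rows = [dict(zip(col_headers, row)) for row in txn.fetchall()]
    print(rows)
    # [{'media_origin': 'example.org', 'media_id': 'abcdef', 'media_length': 1024}]
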
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index b8fc28f97b..34bf3a1880 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
# Do not add more reserved users than the total allowable number
- self.new_transaction(
+ self.db.new_transaction(
dbconn,
"initialise_mau_threepids",
[],
@@ -146,7 +146,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn.execute(sql, query_args)
reserved_users = yield self.get_registered_reserved_users()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
# It seems poor to invalidate the whole cache, Postgres supports
@@ -174,7 +174,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- return self.runInteraction("count_users", _count_users)
+ return self.db.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def get_registered_reserved_users(self):
@@ -217,7 +217,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
- yield self.runInteraction(
+ yield self.db.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
@@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self.simple_upsert_txn(
+ is_insert = self.db.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py
index 650e49750e..cc21437e92 100644
--- a/synapse/storage/data_stores/main/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self.simple_insert(
+ return self.db.simple_insert(
table="open_id_tokens",
values={
"token": token,
@@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
+ return self.db.runInteraction(
+ "get_user_id_for_token", get_user_id_for_token_txn
+ )
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index a5e121efd1..a2c83e0867 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore):
)
with stream_ordering_manager as stream_orderings:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
@@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore):
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="presence_stream",
values=[
@@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn
)
@@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids):
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.simple_insert(
+ return self.db.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
@@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore):
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.simple_delete_one(
+ return self.db.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index c8b5b60301..2b52cf9c1a 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
- profile = yield self.simple_select_one(
+ profile = yield self.db.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_displayname(self, user_localpart):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
@@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_avatar_url(self, user_localpart):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_from_remote_profile_cache(self, user_id):
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
)
def create_profile(self, user_localpart):
- return self.simple_insert(
+ return self.db.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
@@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.simple_update(
+ return self.db.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
- yield self.simple_delete(
+ yield self.db.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
txn.execute(sql, (last_checked,))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
if res:
return True
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 75bd499bcd..de682cc63a 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -75,7 +75,7 @@ class PushRulesWorkerStore(
def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
- push_rules_prefill, push_rules_id = self.get_cache_dict(
+ push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
"push_rules_stream",
entity_column="user_id",
@@ -100,7 +100,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
- rows = yield self.simple_select_list(
+ rows = yield self.db.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -124,7 +124,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.simple_select_list(
+ results = yield self.db.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -146,7 +146,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.runInteraction(
+ return self.db.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@@ -162,7 +162,7 @@ class PushRulesWorkerStore(
results = {user_id: [] for user_id in user_ids}
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -320,7 +320,7 @@ class PushRulesWorkerStore(
results = {user_id: {} for user_id in user_ids}
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -350,7 +350,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -364,7 +364,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore):
relative_to_rule = before or after
- res = self.simple_select_one_txn(
+ res = self.db.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore):
# We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next()
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="push_rules",
values={
@@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
- self.simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.runInteraction(
+ yield self.db.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
@@ -582,7 +582,7 @@ class PushRuleStore(PushRulesWorkerStore):
def set_push_rule_enabled(self, user_id, rule_id, enabled):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore):
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
"push_rules_enable",
{"user_name": user_id, "rule_id": rule_id},
@@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"push_rules",
{"user_name": user_id, "rule_id": rule_id},
@@ -655,7 +655,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.runInteraction(
+ yield self.db.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore):
if data is not None:
values.update(data)
- self.simple_insert_txn(txn, "push_rules_stream", values=values)
+ self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
@@ -699,7 +699,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
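
The push-rule writers above all wrap their `db.runInteraction` call in `with self._push_rules_stream_id_gen.get_next() as ids:`, so a stream position is claimed only for the duration of the write. A toy allocator showing the shape; this is an assumption for illustration, as Synapse's real generator also tracks in-flight ids before advancing the current token:

    import contextlib
    import itertools

    _stream_ids = itertools.count(1)


    @contextlib.contextmanager
    def get_next():
        stream_id = next(_stream_ids)
        try:
            yield stream_id
        finally:
            pass  # the real generator marks the id as persisted here


    with get_next() as stream_id:
        print("persisting push rule change at stream position", stream_id)
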
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index d5a169872b..f07309ef09 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
- ret = yield self.simple_select_one_onecol(
+ ret = yield self.db.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
- ret = yield self.simple_select_list(
+ ret = yield self.db.simple_select_list(
"pushers",
keyvalues,
[
@@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore):
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.runInteraction("get_all_pushers", get_pushers)
+ rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
@@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
return updated, deleted
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
return results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -230,7 +230,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.simple_upsert(
+ yield self.db.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
            # invalidate, since the user might not have had a pusher before

- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
@@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
+ yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
- yield self.simple_update_one(
+ yield self.db.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
@@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore):
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.simple_update(
+ updated = yield self.db.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.simple_update(
+ yield self.db.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
@@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
- res = yield self.simple_select_list(
+ res = yield self.db.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -362,7 +362,7 @@ class PusherStore(PusherWorkerStore):
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.simple_upsert(
+ yield self.db.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
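
The `# no need to lock` comments above rely on `pushers` having a unique key over (app_id, pushkey, user_name) and `pusher_throttle` a primary key on (pusher, room_id): with such an index in place, `simple_upsert` can use a native upsert (or retry on conflict) instead of taking a table lock. A sqlite illustration of the effect, using INSERT OR REPLACE as the closest portable one-liner; Synapse's `simple_upsert` is more careful than this, since OR REPLACE deletes and reinserts the row:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    txn = conn.cursor()
    txn.execute(
        "CREATE TABLE pushers (app_id TEXT, pushkey TEXT, user_name TEXT,"
        " last_stream_ordering BIGINT, UNIQUE (app_id, pushkey, user_name))"
    )

    for ordering in (10, 20):
        txn.execute(
            "INSERT OR REPLACE INTO pushers"
            " (app_id, pushkey, user_name, last_stream_ordering) VALUES (?, ?, ?, ?)",
            ("m.http", "a-push-key", "@user:example.org", ordering),
        )

    txn.execute("SELECT last_stream_ordering FROM pushers")
    print(txn.fetchall())  # [(20,)] -- still one row for the key
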
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 380f388e30..ac2d45bd5c 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
- return self.simple_select_list(
+ return self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.simple_select_list(
+ rows = yield self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -108,7 +108,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
+ rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
return {
row[0]: {
"event_id": row[1],
@@ -187,11 +187,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
return rows
- rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
+ rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -237,9 +237,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
+ txn_results = yield self.db.runInteraction(
+ "_get_linearized_receipts_for_rooms", f
+ )
results = {}
for row in txn_results:
@@ -282,7 +284,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return list(r[0:5] + (json.loads(r[5]),) for r in txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@@ -335,7 +337,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
- res = self.simple_select_one_txn(
+ res = self.db.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -388,7 +390,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -398,7 +400,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
},
)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="receipts_linearized",
values={
@@ -453,13 +455,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.runInteraction(
+ linearized_event_id = yield self.db.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.runInteraction(
+ event_ts = yield self.db.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -488,7 +490,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.runInteraction(
+ return self.db.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -514,7 +516,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -523,7 +525,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="receipts_graph",
values={
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index debc6706f5..8f9aa87ceb 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -45,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -94,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
@@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def set_account_validity_for_user_txn(txn):
- self.simple_update_txn(
+ self.db.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@@ -151,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
- yield self.runInteraction(
+ yield self.db.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
@@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self.simple_update_one(
+ yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
@@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
@@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
@@ -229,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
@@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self.simple_update_one(
+ yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
@@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args:
user_id (str): ID of the user to remove from the account validity table.
"""
- yield self.simple_delete_one(
+ yield self.db.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
@@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"admin": 1 if admin else 0},
@@ -316,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
txn.execute(sql, (token,))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]
@@ -332,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user's 'user_type' is null or an empty string
"""
- res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+ res = yield self.db.runInteraction(
+ "is_real_user", self.is_real_user_txn, user_id
+ )
return res
@cachedInlineCallbacks()
@@ -345,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
return res
def is_real_user_txn(self, txn, user_id):
- res = self.simple_select_one_onecol_txn(
+ res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -361,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None
def is_support_user_txn(self, txn, user_id):
- res = self.simple_select_one_onecol_txn(
+ res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -380,7 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
- return self.runInteraction("get_users_by_id_case_insensitive", f)
+ return self.db.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@@ -394,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: the mxid of the user, or None if they are not known
"""
- return await self.simple_select_one_onecol(
+ return await self.db.simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
@@ -408,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.runInteraction("count_users", _count_users)
+ ret = yield self.db.runInteraction("count_users", _count_users)
return ret
def count_daily_user_type(self):
@@ -445,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
- return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+ return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
@@ -459,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_users", _count_users)
+ ret = yield self.db.runInteraction("count_users", _count_users)
return ret
@defer.inlineCallbacks
@@ -468,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.runInteraction("count_real_users", _count_users)
+ ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
@defer.inlineCallbacks
@@ -503,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return (
(
- yield self.runInteraction(
+ yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
@@ -520,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
- user_id = yield self.runInteraction(
+ user_id = yield self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@@ -536,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
- ret = self.simple_select_one_txn(
+ ret = self.db.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
@@ -549,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self.simple_upsert(
+ yield self.db.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -557,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
- ret = yield self.simple_select_list(
+ ret = yield self.db.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
@@ -566,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret
def user_delete_threepid(self, user_id, medium, address):
- return self.simple_delete(
+ return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid",
@@ -579,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id: The user id to delete all threepids of
"""
- return self.simple_delete(
+ return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id},
desc="user_delete_threepids",
@@ -601,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
# We need to use an upsert, in case the user had already bound the
# threepid
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -627,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore):
medium (str): The medium of the threepid (e.g. "email")
address (str): The address of the threepid (e.g. "bob@example.com")
"""
- return self.simple_select_list(
+ return self.db.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@@ -648,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
- return self.simple_delete(
+ return self.db.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -671,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@@ -689,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Deferred[bool]: The requested value.
"""
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@@ -756,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return None
return rows[0]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
@@ -776,18 +778,18 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def delete_threepid_session_txn(txn):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@@ -857,7 +859,7 @@ class RegistrationBackgroundUpdateStore(
(last_user, batch_size),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return True, 0
@@ -880,7 +882,7 @@ class RegistrationBackgroundUpdateStore(
else:
return False, len(rows)
- end, nb_processed = yield self.runInteraction(
+ end, nb_processed = yield self.db.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
@@ -911,7 +913,7 @@ class RegistrationBackgroundUpdateStore(
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
@@ -961,7 +963,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self.simple_insert(
+ yield self.db.simple_insert(
"access_tokens",
{
"id": next_id,
@@ -1003,7 +1005,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"register_user",
self._register_user,
user_id,
@@ -1037,7 +1039,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
- self.simple_select_one_txn(
+ self.db.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1045,7 +1047,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False,
)
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1059,7 +1061,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
else:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"users",
values={
@@ -1114,7 +1116,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self.simple_insert(
+ return self.db.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@@ -1132,12 +1134,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def user_set_password_hash_txn(txn):
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
+ return self.db.runInteraction(
+ "user_set_password_hash", user_set_password_hash_txn
+ )
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@@ -1152,7 +1156,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1160,7 +1164,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_consent_version", f)
+ return self.db.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server
@@ -1176,7 +1180,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1184,7 +1188,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_consent_server_notice_sent", f)
+ return self.db.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
@@ -1230,11 +1234,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
- return self.runInteraction("user_delete_access_tokens", f)
+ return self.db.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
- self.simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -1242,11 +1246,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
- return self.runInteraction("delete_access_token", f)
+ return self.db.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
- res = yield self.simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -1261,7 +1265,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self.simple_insert(
+ return self.db.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@@ -1274,7 +1278,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self.simple_delete(
+ return self.db.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@@ -1285,7 +1289,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1315,7 +1319,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
- row = self.simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1333,7 +1337,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
400, "This client_secret does not match the provided session_id"
)
- row = self.simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@@ -1358,7 +1362,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Looks good. Validate the session
- self.simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1368,7 +1372,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
- return self.runInteraction(
+ return self.db.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
@@ -1401,7 +1405,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
if validated_at:
insertion_values["validated_at"] = validated_at
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@@ -1439,7 +1443,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1452,7 +1456,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Create a new validation token with this session ID
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@@ -1463,7 +1467,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
@@ -1478,7 +1482,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
return txn.execute(sql, (ts,))
- return self.runInteraction(
+ return self.db.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
@@ -1493,7 +1497,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
deactivated (bool): The value to set for `deactivated`.
"""
- yield self.runInteraction(
+ yield self.db.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@@ -1501,7 +1505,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -1529,14 +1533,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.execute(sql, [])
- res = self.cursor_to_dict(txn)
+ res = self.db.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
- yield self.runInteraction(
+ yield self.db.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
@@ -1560,7 +1564,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expiration_ts,
)
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},
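The registration hunks lean heavily on `simple_upsert`, whose contract is visible in the call sites above: `keyvalues` identify the row, `values` are always written, and `insertion_values` apply only when the row is first created. A sketch of that contract under SQLite, ignoring the table locking and native-UPSERT paths the real helper chooses between:

import sqlite3


def simple_upsert_txn(txn, table, keyvalues, values, insertion_values=None):
    insertion_values = insertion_values or {}
    where = " AND ".join("%s = ?" % k for k in keyvalues)
    if values:
        # Try an UPDATE first; fall through to INSERT if no row matched.
        sets = ", ".join("%s = ?" % k for k in values)
        txn.execute(
            "UPDATE %s SET %s WHERE %s" % (table, sets, where),
            list(values.values()) + list(keyvalues.values()),
        )
        if txn.rowcount > 0:
            return
    else:
        # Nothing to update: only insert if the row is missing.
        txn.execute(
            "SELECT 1 FROM %s WHERE %s" % (table, where),
            list(keyvalues.values()),
        )
        if txn.fetchone():
            return
    allvalues = {**keyvalues, **values, **insertion_values}
    txn.execute(
        "INSERT INTO %s (%s) VALUES (%s)"
        % (table, ", ".join(allvalues), ", ".join("?" for _ in allvalues)),
        list(allvalues.values()),
    )


conn = sqlite3.connect(":memory:")
txn = conn.cursor()
txn.execute("CREATE TABLE user_threepids (medium TEXT, address TEXT, user_id TEXT)")
simple_upsert_txn(
    txn,
    "user_threepids",
    {"medium": "email", "address": "bob@example.com"},
    {"user_id": "@bob:example.com"},
)
print(txn.execute("SELECT user_id FROM user_threepids").fetchall())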
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index f81f9279a1..1c07c7a425 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="rejections",
values={
@@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
)
def get_rejection_reason(self, event_id):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index aa5e10538b..046c2b4845 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore):
if row:
return row[0]
- edit_id = yield self.runInteraction(
+ edit_id = yield self.db.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
)
@@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
- return self.runInteraction(
+ return self.db.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
@@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore):
aggregation_key = relation.get("key")
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="event_relations",
values={
@@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore):
redacted_event_id (str): The event that was redacted.
"""
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index f309e3640c..a26ed47afc 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -54,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -63,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore):
)
def get_public_room_ids(self):
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
@@ -120,7 +120,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
+ return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks
def get_largest_public_rooms(
@@ -253,21 +253,21 @@ class RoomWorkerStore(SQLBaseStore):
def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
- results = self.cursor_to_dict(txn)
+ results = self.db.cursor_to_dict(txn)
if not forwards:
results.reverse()
return results
- ret_val = yield self.runInteraction(
+ ret_val = yield self.db.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
defer.returnValue(ret_val)
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -288,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimiting has been
disabled for that user entirely.
"""
- row = yield self.simple_select_one(
+ row = yield self.db.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@@ -330,9 +330,9 @@ class RoomWorkerStore(SQLBaseStore):
(room_id,),
)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- ret = yield self.runInteraction(
+ ret = yield self.db.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
)
@@ -396,7 +396,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
(last_room, batch_size),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return True
@@ -408,7 +408,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
ev = json.loads(row["json"])
retention_policy = json.dumps(ev["content"])
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@@ -430,7 +430,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore):
else:
return False
- end = yield self.runInteraction(
+ end = yield self.db.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
)
@@ -461,7 +461,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
try:
def store_room_txn(txn, next_id):
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"rooms",
{
@@ -471,7 +471,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
if is_public:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -482,7 +482,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction("store_room_txn", store_room_txn, next_id)
+ yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -490,14 +490,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"is_public": is_public},
)
- entries = self.simple_select_list_txn(
+ entries = self.db.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -515,7 +515,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -528,7 +528,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@@ -555,7 +555,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="appservice_room_list",
values={
@@ -568,7 +568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# We've already inserted, nothing to do.
return
else:
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
@@ -578,7 +578,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- entries = self.simple_select_list_txn(
+ entries = self.db.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -596,7 +596,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -609,7 +609,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
next_id,
@@ -626,7 +626,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.runInteraction("get_rooms", f)
+ return self.db.runInteraction("get_rooms", f)
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
@@ -660,7 +660,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# Ignore the event if one of the values isn't an integer.
return
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@@ -679,7 +679,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next()
- return self.simple_insert(
+ return self.db.simple_insert(
table="event_reports",
values={
"id": next_id,
@@ -712,7 +712,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
if prev_id == current_id:
return defer.succeed([])
- return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
@@ -725,14 +727,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Returns:
Deferred
"""
- yield self.simple_upsert(
+ yield self.db.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"user_id": user_id},
desc="block_room",
)
- yield self.runInteraction(
+ yield self.db.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
@@ -763,7 +765,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return local_media_mxcs, remote_media_mxcs
- return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+ return self.db.runInteraction(
+ "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+ )
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
@@ -802,7 +806,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return total_media_quarantined
- return self.runInteraction(
+ return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -907,7 +911,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql, args)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
rooms_dict = {}
for row in rows:
@@ -923,7 +927,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
@@ -936,7 +940,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict
- rooms = yield self.runInteraction(
+ rooms = yield self.db.runInteraction(
"get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn,
)
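store_room and the public-room-list writers above pair a stream-ID generator context manager with `self.db.runInteraction`: the ID is reserved on entry and the row is persisted before the block exits. A toy, single-writer version of that pattern (the real generators also track in-flight IDs so readers never observe gaps; `ToyIdGenerator` is an illustrative name):

from contextlib import contextmanager


class ToyIdGenerator(object):
    """Toy version of the stream-ID generators used above."""

    def __init__(self, current=0):
        self._current = current

    @contextmanager
    def get_next(self):
        self._current += 1
        yield self._current


gen = ToyIdGenerator()
with gen.get_next() as next_id:
    # In the patch, this is where db.runInteraction(...) persists the row.
    print("writing public_room_list_stream row with stream_id", next_id)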
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index fe2428a281..7f4d02b25b 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -116,7 +116,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.runInteraction("get_known_servers", _transact)
+ count = yield self.db.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -128,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership column is up to date
"""
- pending_update = self.simple_select_one_txn(
+ pending_update = self.db.simple_select_one_txn(
txn,
table="background_updates",
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@@ -144,7 +144,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
15.0,
run_as_background_process,
"_check_safe_current_state_events_membership_updated",
- self.runInteraction,
+ self.db.runInteraction,
"_check_safe_current_state_events_membership_updated",
self._check_safe_current_state_events_membership_updated_txn,
)
@@ -161,7 +161,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@@ -269,7 +269,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.runInteraction("get_room_summary", _get_room_summary_txn)
+ return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
def _get_user_counts_in_room_txn(self, txn, room_id):
"""
@@ -339,7 +339,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not membership_list:
return defer.succeed(None)
- rooms = yield self.runInteraction(
+ rooms = yield self.db.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
user_id,
@@ -392,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
+ results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
if do_invite:
sql = (
@@ -412,7 +412,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
)
- for r in self.cursor_to_dict(txn)
+ for r in self.db.cursor_to_dict(txn)
)
return results
@@ -603,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
to `user_id` and ProfileInfo (or None if not a join event).
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -643,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause)
+ rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
if not rows:
return False
@@ -683,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause)
+ rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
if not rows:
return False
@@ -753,7 +753,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall()
return rows[0][0]
- count = yield self.runInteraction("did_forget_membership", f)
+ count = yield self.db.runInteraction("did_forget_membership", f)
return count == 0
@cached()
@@ -790,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return set(row[0] for row in txn if row[1] == 0)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@@ -805,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred[set[str]]: Set of room IDs.
"""
- room_ids = yield self.simple_select_onecol(
+ room_ids = yield self.db.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@@ -820,7 +820,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Get user_id and membership of a set of event IDs.
"""
- return self.simple_select_many_batch(
+ return self.db.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -874,7 +874,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return 0
@@ -915,7 +915,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
return len(rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
@@ -971,7 +971,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
# string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "")
- row_count, finished = yield self.runInteraction(
+ row_count, finished = yield self.db.runInteraction(
"_background_current_state_membership_update",
_background_current_state_membership_txn,
last_processed_room,
@@ -990,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="room_memberships",
values=[
@@ -1028,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_invites",
values={
@@ -1068,7 +1068,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+ yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
@@ -1091,7 +1091,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.runInteraction("forget_membership", f)
+ return self.db.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):
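Several membership queries above go through `simple_select_many_batch(table=..., column=..., iterable=..., retcols=...)`. A sketch of the batching it implies, assuming the helper simply chunks the iterable into bounded `IN (...)` queries (the default batch size of 100 here is illustrative):

import sqlite3


def select_many_batch(txn, table, column, iterable, retcols, batch_size=100):
    # Chunk the iterable so no single statement carries an unbounded
    # parameter list.
    items = list(iterable)
    rows = []
    for i in range(0, len(items), batch_size):
        chunk = items[i : i + batch_size]
        sql = "SELECT %s FROM %s WHERE %s IN (%s)" % (
            ", ".join(retcols),
            table,
            column,
            ", ".join("?" for _ in chunk),
        )
        txn.execute(sql, chunk)
        rows.extend(txn.fetchall())
    return rows


conn = sqlite3.connect(":memory:")
txn = conn.cursor()
txn.execute("CREATE TABLE room_memberships (event_id TEXT, membership TEXT)")
txn.executemany(
    "INSERT INTO room_memberships VALUES (?, ?)",
    [("$a", "join"), ("$b", "leave")],
)
print(
    select_many_batch(
        txn, "room_memberships", "event_id", ["$a", "$b"], ["membership"], 1
    )
)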
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index f735cf095c..55a604850e 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -93,7 +93,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return 0
@@ -159,7 +159,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
return len(event_search_rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
@@ -206,7 +206,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
- yield self.runWithConnection(create_index)
+ yield self.db.runWithConnection(create_index)
yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
return 1
@@ -237,12 +237,12 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
)
conn.set_session(autocommit=False)
- yield self.runWithConnection(create_index)
+ yield self.db.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
- yield self.runInteraction(
+ yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
@@ -280,7 +280,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
return len(rows), True
- num_rows, finished = yield self.runInteraction(
+ num_rows, finished = yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
@@ -441,7 +441,9 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_msgs", self.db.cursor_to_dict, sql, *args
+ )
results = list(filter(lambda row: row["room_id"] in room_ids, results))
@@ -455,8 +457,8 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self.execute(
- "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ count_results = yield self.db.execute(
+ "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -586,7 +588,9 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit)
- results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_rooms", self.db.cursor_to_dict, sql, *args
+ )
results = list(filter(lambda row: row["room_id"] in room_ids, results))
@@ -600,8 +604,8 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self.execute(
- "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ count_results = yield self.db.execute(
+ "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -686,7 +690,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.runInteraction("_find_highlights", f)
+ return self.db.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
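The search hunks use a second `Database` entry point, `self.db.execute(desc, decoder, sql, *args)`, passing `self.db.cursor_to_dict` as the decoder that consumes the cursor inside the transaction. A sketch of that shape, with sqlite3 standing in for the connection pool:

import sqlite3


def execute(conn, desc, decoder, query, *args):
    # Run a single query; if a decoder is given, let it consume the cursor.
    txn = conn.cursor()
    txn.execute(query, args)
    return decoder(txn) if decoder else txn.fetchall()


def cursor_to_dict(txn):
    col_headers = [c[0] for c in txn.description]
    return [dict(zip(col_headers, row)) for row in txn.fetchall()]


conn = sqlite3.connect(":memory:")
print(execute(conn, "search_msgs", cursor_to_dict, "SELECT 'hello' AS body"))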
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index f3da29ce14..563216b63c 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
- return self.runInteraction("get_event_reference_hashes", f)
+ return self.db.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
@@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
}
)
- self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+ self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 2b33ec1a35..851e81d6b3 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
count = 0
while next_group:
- next_group = self.simple_select_one_onecol_txn(
+ next_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
):
break
- next_group = self.simple_select_one_onecol_txn(
+ next_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -348,7 +348,9 @@ class StateGroupWorkerStore(
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
- return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
+ return self.db.runInteraction(
+ "get_current_state_ids", _get_current_state_ids_txn
+ )
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -392,7 +394,7 @@ class StateGroupWorkerStore(
return results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@@ -431,7 +433,7 @@ class StateGroupWorkerStore(
"""
def _get_state_group_delta_txn(txn):
- prev_group = self.simple_select_one_onecol_txn(
+ prev_group = self.db.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
@@ -442,7 +444,7 @@ class StateGroupWorkerStore(
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.simple_select_list_txn(
+ delta_ids = self.db.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
@@ -454,7 +456,9 @@ class StateGroupWorkerStore(
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
- return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+ return self.db.runInteraction(
+ "get_state_group_delta", _get_state_group_delta_txn
+ )
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -540,7 +544,7 @@ class StateGroupWorkerStore(
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@@ -644,7 +648,7 @@ class StateGroupWorkerStore(
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -661,7 +665,7 @@ class StateGroupWorkerStore(
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
@@ -902,7 +906,7 @@ class StateGroupWorkerStore(
state_group = self.database_engine.get_next_state_group_id(txn)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -911,7 +915,7 @@ class StateGroupWorkerStore(
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if prev_group:
- is_in_db = self.simple_select_one_onecol_txn(
+ is_in_db = self.db.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
@@ -926,13 +930,13 @@ class StateGroupWorkerStore(
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -947,7 +951,7 @@ class StateGroupWorkerStore(
],
)
else:
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -993,7 +997,7 @@ class StateGroupWorkerStore(
return state_group
- return self.runInteraction("store_state_group", _store_state_group_txn)
+ return self.db.runInteraction("store_state_group", _store_state_group_txn)
@defer.inlineCallbacks
def get_referenced_state_groups(self, state_groups):
@@ -1007,7 +1011,7 @@ class StateGroupWorkerStore(
referenced.
"""
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
@@ -1065,7 +1069,7 @@ class StateBackgroundUpdateStore(
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
- rows = yield self.execute(
+ rows = yield self.db.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@@ -1135,13 +1139,13 @@ class StateBackgroundUpdateStore(
if prev_state.get(key, None) != value
}
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
)
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="state_group_edges",
values={
@@ -1150,13 +1154,13 @@ class StateBackgroundUpdateStore(
},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
)
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -1183,7 +1187,7 @@ class StateBackgroundUpdateStore(
return False, batch_size
- finished, result = yield self.runInteraction(
+ finished, result = yield self.db.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
@@ -1218,7 +1222,7 @@ class StateBackgroundUpdateStore(
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- yield self.runWithConnection(reindex_txn)
+ yield self.db.runWithConnection(reindex_txn)
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
@@ -1263,7 +1267,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
state_groups[event.event_id] = context.state_group
- self.simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
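The state hunks repeatedly walk `state_group_edges`: a state group either stores a full snapshot or a delta over `prev_state_group`, with `MAX_STATE_DELTA_HOPS` bounding the chain so reads stay cheap. A dict-based sketch of how such a chain resolves, with the edge and delta values invented for illustration:

edges = {3: 2, 2: 1}  # state_group -> prev_state_group
deltas = {
    1: {("m.room.create", ""): "$create"},
    2: {("m.room.name", ""): "$name_v1"},
    3: {("m.room.name", ""): "$name_v2"},
}


def resolve(group):
    # Collect the chain of groups back to the root snapshot.
    chain = []
    while group is not None:
        chain.append(group)
        group = edges.get(group)
    # Overlay deltas oldest-first, so newer entries win.
    state = {}
    for g in reversed(chain):
        state.update(deltas[g])
    return state


print(resolve(3))  # the room name resolves to "$name_v2"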
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 03b908026b..12c982cb26 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.cursor_to_dict(txn)
+ return clipped_stream_id, self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
- return self.simple_select_one_onecol_txn(
+ return self.db.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
@@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore):
)
def get_max_stream_id_in_current_state_deltas(self):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 3aeba859fd..974ffc15bd 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -117,7 +117,7 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn]
- users_to_work_on = yield self.runInteraction(
+ users_to_work_on = yield self.db.runInteraction(
"_populate_stats_process_users", _get_next_batch
)
@@ -130,7 +130,7 @@ class StatsStore(StateDeltasStore):
yield self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_stats_process_users",
self._background_update_progress_txn,
"populate_stats_process_users",
@@ -160,7 +160,7 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn]
- rooms_to_work_on = yield self.runInteraction(
+ rooms_to_work_on = yield self.db.runInteraction(
"populate_stats_rooms_get_batch", _get_next_batch
)
@@ -173,7 +173,7 @@ class StatsStore(StateDeltasStore):
yield self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id
- yield self.runInteraction(
+ yield self.db.runInteraction(
"_populate_stats_process_rooms",
self._background_update_progress_txn,
"populate_stats_process_rooms",
@@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore):
"""
Returns the stats processor position.
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore):
if field and "\0" in field:
fields[col] = None
- return self.simple_upsert(
+ return self.db.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
@@ -236,7 +236,7 @@ class StatsStore(StateDeltasStore):
Deferred[list[dict]], where the dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_statistics_for_subject",
self._get_statistics_for_subject_txn,
stats_type,
@@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore):
ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
)
- slice_list = self.simple_select_list_paginate_txn(
+ slice_list = self.db.simple_select_list_paginate_txn(
txn,
table + "_historical",
{id_col: stats_id},
@@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore):
"name", "topic", "canonical_alias", "avatar", "join_rules",
"history_visibility"
"""
- return self.simple_select_one(
+ return self.db.simple_select_one(
"room_stats_state",
{"room_id": room_id},
retcols=(
@@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore):
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
@@ -344,14 +344,14 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=stream_id,
)
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": stream_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn
)
@@ -382,7 +382,7 @@ class StatsStore(StateDeltasStore):
Does not work with per-slice fields.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_stats_delta",
self._update_stats_delta_txn,
ts,
@@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore):
else:
self.database_engine.lock_table(txn, table)
retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
- current_row = self.simple_select_one_txn(
+ current_row = self.db.simple_select_one_txn(
txn, table, keyvalues, retcols, allow_none=True
)
if current_row is None:
merged_dict = {**keyvalues, **absolutes, **additive_relatives}
- self.simple_insert_txn(txn, table, merged_dict)
+ self.db.simple_insert_txn(txn, table, merged_dict)
else:
for (key, val) in additive_relatives.items():
current_row[key] += val
current_row.update(absolutes)
- self.simple_update_one_txn(txn, table, keyvalues, current_row)
+ self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
def _upsert_copy_from_table_with_additive_relatives_txn(
self,
@@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, qargs)
else:
self.database_engine.lock_table(txn, into_table)
- src_row = self.simple_select_one_txn(
+ src_row = self.db.simple_select_one_txn(
txn, src_table, keyvalues, copy_columns
)
all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
- dest_current_row = self.simple_select_one_txn(
+ dest_current_row = self.db.simple_select_one_txn(
txn,
into_table,
keyvalues=all_dest_keyvalues,
@@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore):
**src_row,
**additive_relatives,
}
- self.simple_insert_txn(txn, into_table, merged_dict)
+ self.db.simple_insert_txn(txn, into_table, merged_dict)
else:
for (key, val) in additive_relatives.items():
src_row[key] = dest_current_row[key] + val
- self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+ self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
"""Fetches the counts of events in the given range of stream IDs.
@@ -652,7 +652,7 @@ class StatsStore(StateDeltasStore):
changes.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn,
min_pos,
@@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore):
def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
- rows = self.simple_select_many_txn(
+ rows = self.db.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
@@ -791,7 +791,7 @@ class StatsStore(StateDeltasStore):
current_state_events_count,
users_in_room,
pos,
- ) = yield self.runInteraction(
+ ) = yield self.db.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats
)
@@ -866,7 +866,7 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone()
return count, pos
- joined_rooms, pos = yield self.runInteraction(
+ joined_rooms, pos = yield self.db.runInteraction(
"calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn,
)
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 60487c4559..2ff8c57109 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
super(StreamWorkerStore, self).__init__(db_conn, hs)
events_max = self.get_room_max_stream_ordering()
- event_cache_prefill, min_event_val = self.get_cache_dict(
+ event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@@ -400,7 +400,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+ rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@@ -450,7 +450,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.runInteraction("get_membership_changes_for_user", f)
+ rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.runInteraction(
+ rows, token = yield self.db.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -548,7 +548,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.runInteraction("get_room_event_after_stream_ordering", _f)
+ return self.db.runInteraction("get_room_event_after_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@@ -562,7 +562,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if room_id is None:
return "s%d" % (token,)
else:
- topo = yield self.runInteraction(
+ topo = yield self.db.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
return "t%d-%d" % (topo, token)
@@ -576,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "s%d" stream token.
"""
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
@@ -589,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "t%d-%d" topological token.
"""
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
@@ -613,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self.execute(
+ return self.db.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
@@ -667,7 +667,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = yield self.runInteraction(
+ results = yield self.db.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -709,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = self.simple_select_one_txn(
+ results = self.db.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -788,7 +788,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.runInteraction(
+ upper_bound, event_ids = yield self.db.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
@@ -797,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, events
def get_federation_out_pos(self, typ):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
@@ -805,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
def update_federation_out_pos(self, typ, stream_id):
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ},
updatevalues={"stream_id": stream_id},
@@ -956,7 +956,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.runInteraction(
+ rows, token = yield self.db.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 85012403be..2aa1bafd48 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag strings to tag content.
"""
- deferred = self.simple_select_list(
+ deferred = self.db.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@@ -78,7 +78,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = yield self.runInteraction(
+ tag_ids = yield self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@@ -98,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = yield self.runInteraction(
+ tags = yield self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
@@ -135,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
return {}
- room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
+ room_ids = yield self.db.runInteraction(
+ "get_updated_tags", get_updated_tags_txn
+ )
results = {}
if room_ids:
@@ -153,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A deferred list of string tags.
"""
- return self.simple_select_list(
+ return self.db.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
@@ -178,7 +180,7 @@ class TagsStore(TagsWorkerStore):
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -187,7 +189,7 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction("add_tag", add_tag_txn, next_id)
+ yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -210,7 +212,7 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+ yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index c162f3ea16..c0d155a43c 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -77,7 +77,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- result = self.simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self.simple_insert(
+ return self.db.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -148,7 +148,7 @@ class TransactionStore(SQLBaseStore):
if result is not SENTINEL:
return result
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
@@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore):
return result
def _get_destination_retry_timings(self, txn, destination):
- result = self.simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -187,7 +187,7 @@ class TransactionStore(SQLBaseStore):
"""
self._destination_retry_cache.pop(destination, None)
- return self.runInteraction(
+ return self.db.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
@@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore):
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
- prev_row = self.simple_select_one_txn(
+ prev_row = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore):
)
if not prev_row:
- self.simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="destinations",
values={
@@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore):
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
- self.simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
@@ -270,4 +270,6 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
- return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
+ return self.db.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 1a85aabbfb..7118bd62f3 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
- self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
yield self._end_background_update("populate_user_directory_createtables")
return 1
@@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self.simple_select_one_onecol(
+ position = yield self.db.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
yield self.update_user_directory_stream_pos(position)
@@ -126,7 +126,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
@@ -170,7 +170,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
return rooms_to_work_on
- rooms_to_work_on = yield self.runInteraction(
+ rooms_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
@@ -243,10 +243,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory",
self._background_update_progress_txn,
"populate_user_directory_process_rooms",
@@ -291,7 +291,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
return users_to_work_on
- users_to_work_on = yield self.runInteraction(
+ users_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
@@ -312,10 +312,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
# We've finished processing a user. Delete it from the table.
- yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory",
self._background_update_progress_txn,
"populate_user_directory_process_users",
@@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
def _update_profile_in_user_dir_txn(txn):
- new_entry = self.simple_upsert_txn(
+ new_entry = self.db.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
- self.simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@@ -448,7 +448,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.runInteraction(
+ return self.db.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
@@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
def _add_users_who_share_room_txn(txn):
- self.simple_upsert_many_txn(
+ self.db.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@@ -474,7 +474,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
value_values=None,
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
@@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
def _add_users_in_public_rooms_txn(txn):
- self.simple_upsert_many_txn(
+ self.db.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@@ -498,7 +498,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
value_values=None,
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
@@ -513,13 +513,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
- return self.simple_select_one(
+ return self.db.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
def update_user_directory_stream_pos(self, stream_id):
- return self.simple_update_one(
+ return self.db.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@@ -547,42 +547,42 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+ return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
@defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_share_pub = yield self.simple_select_onecol(
+ user_ids_share_pub = yield self.db.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share_priv = yield self.simple_select_onecol(
+ user_ids_share_priv = yield self.db.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
@@ -605,23 +605,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"""
def _remove_user_who_share_room_txn(txn):
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- self.simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns:
list: user_id
"""
- rows = yield self.simple_select_onecol(
+ rows = yield self.db.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self.simple_select_onecol(
+ pub_rows = yield self.db.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
) f2 USING (room_id)
"""
- rows = yield self.execute(
+ rows = yield self.db.execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
return [room_id for room_id, in rows]
def get_user_directory_stream_pos(self):
- return self.simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
@@ -786,7 +786,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_user_dir", self.db.cursor_to_dict, sql, *args
+ )
limited = len(results) > limit
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index 37860af070..af8025bc17 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user has requested erasure
"""
- return self.simple_select_onecol(
+ return self.db.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
@@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.runInteraction("mark_user_erased", f)
+ return self.db.runInteraction("mark_user_erased", f)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
new file mode 100644
index 0000000000..c2e121a001
--- /dev/null
+++ b/synapse/storage/database.py
@@ -0,0 +1,1485 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import random
+import sys
+import time
+from typing import Iterable, Tuple
+
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+
+from prometheus_client import Histogram
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.util.stringutils import exception_to_unicode
+
+# import a function which will return a monotonic time, in seconds
+try:
+ # on python 3, use time.monotonic, since time.clock can go backwards
+ from time import monotonic as monotonic_time
+except ImportError:
+ # ... but python 2 doesn't have it
+ from time import clock as monotonic_time
+
+logger = logging.getLogger(__name__)
+
+try:
+ MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+ # python 3 does not have a maximum int value
+ MAX_TXN_ID = 2 ** 63 - 1
+
+sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
+
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
+
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+
+
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+ "user_ips": "user_ips_device_unique_index",
+ "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+ "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+ "event_search": "event_search_event_id_idx",
+}
+
+
+class LoggingTransaction(object):
+ """An object that almost-transparently proxies for the 'txn' object
+ passed to the constructor. Adds logging and metrics to the .execute()
+ method.
+
+ Args:
+ txn: The database transaction object to wrap.
+ name (str): The name of this transaction, for logging.
+ database_engine (Sqlite3Engine|PostgresEngine)
+ after_callbacks(list|None): A list that callbacks will be appended to
+ that have been added by `call_after` which should be run on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
+ exception_callbacks(list|None): A list that callbacks will be appended
+ to that have been added by `call_on_exception` which should be run
+ if transaction ends with an error. None indicates that no callbacks
+ should be allowed to be scheduled to run.
+ """
+
+ __slots__ = [
+ "txn",
+ "name",
+ "database_engine",
+ "after_callbacks",
+ "exception_callbacks",
+ ]
+
+ def __init__(
+ self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+ ):
+ object.__setattr__(self, "txn", txn)
+ object.__setattr__(self, "name", name)
+ object.__setattr__(self, "database_engine", database_engine)
+ object.__setattr__(self, "after_callbacks", after_callbacks)
+ object.__setattr__(self, "exception_callbacks", exception_callbacks)
+
+ def call_after(self, callback, *args, **kwargs):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
+ """
+ self.after_callbacks.append((callback, args, kwargs))
+
+ def call_on_exception(self, callback, *args, **kwargs):
+ self.exception_callbacks.append((callback, args, kwargs))
+
+ def __getattr__(self, name):
+ return getattr(self.txn, name)
+
+ def __setattr__(self, name, value):
+ setattr(self.txn, name, value)
+
+ def __iter__(self):
+ return self.txn.__iter__()
+
+ def execute_batch(self, sql, args):
+ if isinstance(self.database_engine, PostgresEngine):
+ from psycopg2.extras import execute_batch
+
+ self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+ else:
+ for val in args:
+ self.execute(sql, val)
+
+ def execute(self, sql, *args):
+ self._do_execute(self.txn.execute, sql, *args)
+
+ def executemany(self, sql, *args):
+ self._do_execute(self.txn.executemany, sql, *args)
+
+ def _make_sql_one_line(self, sql):
+ "Strip newlines out of SQL so that the loggers in the DB are on one line"
+ return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
+ def _do_execute(self, func, sql, *args):
+ sql = self._make_sql_one_line(sql)
+
+ # TODO(paul): Maybe use 'info' and 'debug' for values?
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
+ sql = self.database_engine.convert_param_style(sql)
+ if args:
+ try:
+ sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+ except Exception:
+ # Don't let logging failures stop SQL from working
+ pass
+
+ start = time.time()
+
+ try:
+ return func(sql, *args)
+ except Exception as e:
+ logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ raise
+ finally:
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
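
A usage sketch, not part of the patch: `conn` and `database_engine` are assumed to come from the surrounding pool/engine machinery. The proxy behaves like the cursor it wraps, with logging and timing added on execute():

    # Hypothetical illustration only.
    txn = LoggingTransaction(conn.cursor(), "example-1", database_engine)
    txn.execute(
        "SELECT stream_id FROM federation_stream_position WHERE type = ?",
        ("federation",),
    )
    row = txn.fetchone()  # plain cursor attribute, forwarded via __getattr__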
+
+
+class PerformanceCounters(object):
+ def __init__(self):
+ self.current_counters = {}
+ self.previous_counters = {}
+
+ def update(self, key, duration_secs):
+ count, cum_time = self.current_counters.get(key, (0, 0))
+ count += 1
+ cum_time += duration_secs
+ self.current_counters[key] = (count, cum_time)
+
+ def interval(self, interval_duration_secs, limit=3):
+ counters = []
+ for name, (count, cum_time) in iteritems(self.current_counters):
+ prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+ counters.append(
+ (
+ (cum_time - prev_time) / interval_duration_secs,
+ count - prev_count,
+ name,
+ )
+ )
+
+ self.previous_counters = dict(self.current_counters)
+
+ counters.sort(reverse=True)
+
+ top_n_counters = ", ".join(
+ "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+ for ratio, count, name in counters[:limit]
+ )
+
+ return top_n_counters
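
A worked example of the arithmetic (numbers invented for illustration): two transactions of 0.5 s recorded under one key and reported over a 10 s window come out at 10% of wall-clock time.

    counters = PerformanceCounters()
    counters.update("get_user", 0.5)
    counters.update("get_user", 0.5)
    # (cum_time - prev_time) / interval_duration_secs = (1.0 - 0) / 10.0 = 0.1
    print(counters.interval(10.0))  # -> "get_user(2): 10.000%"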
+
+
+class Database(object):
+ _TXN_ID = 0
+
+ def __init__(self, hs):
+ self.hs = hs
+ self._clock = hs.get_clock()
+ self._db_pool = hs.get_db_pool()
+
+ self._previous_txn_total_time = 0
+ self._current_txn_total_time = 0
+ self._previous_loop_ts = 0
+
+ # TODO(paul): These can eventually be removed once the metrics code
+ # is running in mainline, and we have some nice monitoring frontends
+ # to watch it
+ self._txn_perf_counters = PerformanceCounters()
+
+ self.database_engine = hs.database_engine
+
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.database_engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
+ if self.database_engine.can_native_upsert:
+ # Check ASAP (and then later, every 15s) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert,
+ )
+
+ self.rand = random.SystemRandom()
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self.simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+ if update_name not in updates:
+ logger.debug("Now safe to upsert in %s", table)
+ self._unsafe_to_upsert_tables.discard(table)
+
+ # If there's any updates still running, reschedule to run.
+ if updates:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert,
+ )
+
+ def start_profiling(self):
+ self._previous_loop_ts = monotonic_time()
+
+ def loop():
+ curr = self._current_txn_total_time
+ prev = self._previous_txn_total_time
+ self._previous_txn_total_time = curr
+
+ time_now = monotonic_time()
+ time_then = self._previous_loop_ts
+ self._previous_loop_ts = time_now
+
+ duration = time_now - time_then
+ ratio = (curr - prev) / duration
+
+ top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
+
+ perf_logger.info(
+ "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
+ )
+
+ self._clock.looping_call(loop, 10000)
+
+ def new_transaction(
+ self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+ ):
+ start = monotonic_time()
+ txn_id = self._TXN_ID
+
+ # We don't really need these to be unique, so let's stop it from
+ # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
+
+ name = "%s-%x" % (desc, txn_id)
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+
+ try:
+ i = 0
+ N = 5
+ while True:
+ cursor = LoggingTransaction(
+ conn.cursor(),
+ name,
+ self.database_engine,
+ after_callbacks,
+ exception_callbacks,
+ )
+ try:
+ r = func(cursor, *args, **kwargs)
+ conn.commit()
+ return r
+ except self.database_engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warning(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name,
+ exception_to_unicode(e),
+ i,
+ N,
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warning(
+ "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
+ )
+ continue
+ raise
+ except self.database_engine.module.DatabaseError as e:
+ if self.database_engine.is_deadlock(e):
+ logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.database_engine.module.Error as e1:
+ logger.warning(
+ "[TXN EROLL] {%s} %s",
+ name,
+ exception_to_unicode(e1),
+ )
+ continue
+ raise
+ finally:
+ # we're either about to retry with a new cursor, or we're about to
+ # release the connection. Once we release the connection, it could
+ # get used for another query, which might do a conn.rollback().
+ #
+ # In the latter case, even though that probably wouldn't affect the
+ # results of this transaction, python's sqlite will reset all
+ # statements on the connection [1], which will make our cursor
+ # invalid [2].
+ #
+ # In any case, continuing to read rows after commit()ing seems
+ # dubious from the PoV of ACID transactional semantics
+ # (sqlite explicitly says that once you commit, you may see rows
+ # from subsequent updates.)
+ #
+ # In psycopg2, cursors are essentially a client-side fabrication -
+ # all the data is transferred to the client side when the statement
+ # finishes executing - so in theory we could go on streaming results
+ # from the cursor, but attempting to do so would make us
+ # incompatible with sqlite, so let's make sure we're not doing that
+ # by closing the cursor.
+ #
+ # (*named* cursors in psycopg2 are different and are proper server-
+ # side things, but (a) we don't use them and (b) they are implicitly
+ # closed by ending the transaction anyway.)
+ #
+ # In short, if we haven't finished with the cursor yet, that's a
+ # problem waiting to bite us.
+ #
+ # TL;DR: we're done with the cursor, so we can close it.
+ #
+ # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+ # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+ cursor.close()
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
+ raise
+ finally:
+ end = monotonic_time()
+ duration = end - start
+
+ LoggingContext.current_context().add_database_transaction(duration)
+
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
+
+ self._current_txn_total_time += duration
+ self._txn_perf_counters.update(desc, duration)
+ sql_txn_timer.labels(desc).observe(duration)
+
+ @defer.inlineCallbacks
+ def runInteraction(self, desc, func, *args, **kwargs):
+ """Starts a transaction on the database and runs a given function
+
+ Arguments:
+ desc (str): description of the transaction, for logging and metrics
+ func (func): callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ after_callbacks = []
+ exception_callbacks = []
+
+ if LoggingContext.current_context() == LoggingContext.sentinel:
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+ try:
+ result = yield self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
+ )
+
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ return result
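
This is the method that every `self.runInteraction(...)` call site above is being repointed at. A minimal caller sketch, from inside an `@defer.inlineCallbacks` store method (the table comes from the transactions hunk earlier in this diff; the desc string is illustrative):

    def _get_pos_txn(txn):
        # Runs on a database thread, inside a single transaction.
        txn.execute(
            "SELECT stream_id FROM federation_stream_position WHERE type = ?",
            ("federation",),
        )
        return txn.fetchone()[0]

    pos = yield self.db.runInteraction("get_federation_out_pos", _get_pos_txn)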
+
+ @defer.inlineCallbacks
+ def runWithConnection(self, func, *args, **kwargs):
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func (func): callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args (list): positional args to pass to `func`
+ kwargs (dict): named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ parent_context = LoggingContext.current_context()
+ if parent_context == LoggingContext.sentinel:
+ logger.warning(
+ "Starting db connection from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+
+ start_time = monotonic_time()
+
+ def inner_func(conn, *args, **kwargs):
+ with LoggingContext("runWithConnection", parent_context) as context:
+ sched_duration_sec = monotonic_time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
+
+ if self.database_engine.is_connection_closed(conn):
+ logger.debug("Reconnecting closed database connection")
+ conn.reconnect()
+
+ return func(conn, *args, **kwargs)
+
+ result = yield make_deferred_yieldable(
+ self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ )
+
+ return result
+
+ @staticmethod
+ def cursor_to_dict(cursor):
+ """Converts a SQL cursor into an list of dicts.
+
+ Args:
+ cursor : The DBAPI cursor which has executed a query.
+ Returns:
+ A list of dicts where the key is the column header.
+ """
+ col_headers = list(intern(str(column[0])) for column in cursor.description)
+ results = list(dict(zip(col_headers, row)) for row in cursor)
+ return results
+
+ def execute(self, desc, decoder, query, *args):
+ """Runs a single query for a result set.
+
+ Args:
+ decoder - The function which can resolve the cursor results to
+ something meaningful.
+ query - The query string to execute
+ *args - Query args.
+ Returns:
+ The result of decoder(results)
+ """
+
+ def interaction(txn):
+ txn.execute(query, args)
+ if decoder:
+ return decoder(txn)
+ else:
+ return txn.fetchall()
+
+ return self.runInteraction(desc, interaction)
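
For instance, the `search_user_dir` hunk above passes `self.db.cursor_to_dict` as the decoder; with `decoder=None` the raw `fetchall()` tuples come back instead. A hedged sketch (query strings illustrative):

    # decoder=None: resolves to txn.fetchall(), e.g. [(42,)]
    count = yield self.db.execute("count_users", None, "SELECT count(*) FROM users")

    # decoder=cursor_to_dict: resolves to a list of dicts keyed by column name
    users = yield self.db.execute(
        "list_users", self.db.cursor_to_dict, "SELECT name FROM users LIMIT ?", 10
    )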
+
+ # "Simple" SQL API methods that operate on a single table with no JOINs,
+ # no complex WHERE clauses, just a dict of values for columns.
+
+ @defer.inlineCallbacks
+ def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table : string giving the table name
+ values : dict of new column names and values for them
+ or_ignore : bool stating whether to ignore conflicts. If True, no
+ exception is raised when a conflicting row already exists; the
+ function returns False instead.
+ desc : string giving a description of the transaction
+
+ Returns:
+ bool: Whether the row was inserted or not. Only useful when
+ `or_ignore` is True
+ """
+ try:
+ yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ except self.database_engine.module.IntegrityError:
+ # We have to do or_ignore flag at this layer, since we can't reuse
+ # a cursor after we receive an error from the db.
+ if not or_ignore:
+ raise
+ return False
+ return True
+
+ @staticmethod
+ def simple_insert_txn(txn, table, values):
+ keys, vals = zip(*values.items())
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys),
+ )
+
+ txn.execute(sql, vals)
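
To make the string-building concrete (table and values borrowed from the tags hunk above, purely as an illustration):

    values = {"user_id": "@alice:hs", "room_id": "!room:hs", "tag": "favourite"}
    keys, vals = zip(*values.items())
    sql = "INSERT INTO %s (%s) VALUES(%s)" % (
        "room_tags", ", ".join(keys), ", ".join("?" for _ in keys)
    )
    # -> "INSERT INTO room_tags (user_id, room_id, tag) VALUES(?, ?, ?)"
    # (column order follows dict iteration order; vals lines up with keys)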
+
+ def simple_insert_many(self, table, values, desc):
+ return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+
+ @staticmethod
+ def simple_insert_many_txn(txn, table, values):
+ if not values:
+ return
+
+ # This is a *slight* abomination to get a list of tuples of key names
+ # and a list of tuples of value names.
+ #
+ # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+ # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+ #
+ # The sort is to ensure that we don't rely on dictionary iteration
+ # order.
+ keys, vals = zip(
+ *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ )
+
+ for k in keys:
+ if k != keys[0]:
+ raise RuntimeError("All items must have the same keys")
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ ", ".join("?" for _ in keys[0]),
+ )
+
+ txn.executemany(sql, vals)
+
+ @defer.inlineCallbacks
+ def simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="simple_upsert",
+ lock=True,
+ ):
+ """
+
+ `lock` should generally be set to True (the default), but can be set
+ to False if either of the following is true:
+
+ * there is a UNIQUE INDEX on the key columns. In this case a conflict
+ will cause an IntegrityError, and this function will retry
+ the update.
+
+ * we somehow know that we are the only thread which will be updating
+ this table.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ attempts = 0
+ while True:
+ try:
+ result = yield self.runInteraction(
+ desc,
+ self.simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
+ )
+ return result
+ except self.database_engine.module.IntegrityError as e:
+ attempts += 1
+ if attempts >= 5:
+ # don't retry forever, because things other than races
+ # can cause IntegrityErrors
+ raise
+
+ # presumably we raced with another transaction: let's retry.
+ logger.warning(
+ "IntegrityError when upserting into %s; retrying: %s", table, e
+ )
+
+ def simple_upsert_txn(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self.simple_upsert_txn_native_upsert(
+ txn, table, keyvalues, values, insertion_values=insertion_values
+ )
+ else:
+ return self.simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ # We need to lock the table :(, unless we're *really* careful
+ if lock:
+ self.database_engine.lock_table(txn, table)
+
+ def _getwhere(key):
+ # If the value we're passing in is None (aka NULL), we need to use
+ # IS, not =, as NULL = NULL evaluates to NULL (which is not true).
+ if keyvalues[key] is None:
+ return "%s IS ?" % (key,)
+ else:
+ return "%s = ?" % (key,)
+
+ if not values:
+ # If `values` is empty, then all of the values we care about are in
+ # the unique key, so there is nothing to UPDATE. We can just do a
+ # SELECT instead to see if it exists.
+ sql = "SELECT 1 FROM %s WHERE %s" % (
+ table,
+ " AND ".join(_getwhere(k) for k in keyvalues),
+ )
+ sqlargs = list(keyvalues.values())
+ txn.execute(sql, sqlargs)
+ if txn.fetchall():
+ # We have an existing record.
+ return False
+ else:
+ # First try to update.
+ sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in values),
+ " AND ".join(_getwhere(k) for k in keyvalues),
+ )
+ sqlargs = list(values.values()) + list(keyvalues.values())
+
+ txn.execute(sql, sqlargs)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
+
+ # We didn't find any existing rows, so insert a new one
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
+
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ )
+ txn.execute(sql, list(allvalues.values()))
+ # successfully inserted
+ return True
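
Traced on a concrete call (values illustrative, table from the transactions hunk above), the emulated path is UPDATE-then-INSERT:

    # simple_upsert_txn_emulated(
    #     txn, "destinations",
    #     keyvalues={"destination": "example.org"},
    #     values={"retry_interval": 60000},
    # )
    # 1. UPDATE destinations SET retry_interval = ? WHERE destination = ?
    #    with args [60000, "example.org"]; rowcount > 0 means done -> False
    # 2. otherwise: INSERT INTO destinations (destination, retry_interval)
    #    VALUES (?, ?) -> True, a fresh row was created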
+
+ def simple_upsert_txn_native_upsert(
+ self, txn, table, keyvalues, values, insertion_values={}
+ ):
+ """
+ Use the native UPSERT functionality in recent PostgreSQL versions.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ Returns:
+ None
+ """
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(insertion_values)
+
+ if not values:
+ latter = "NOTHING"
+ else:
+ allvalues.update(values)
+ latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+
+ sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ ", ".join(k for k in keyvalues),
+ latter,
+ )
+ txn.execute(sql, list(allvalues.values()))
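
On the native path the same call collapses to a single statement; given the builder above it renders (before the engine's param-style conversion, and assuming insertion-ordered dicts) as:

    # INSERT INTO destinations (destination, retry_interval) VALUES (?, ?)
    #   ON CONFLICT (destination) DO UPDATE SET retry_interval=EXCLUDED.retry_interval
    # executed with args ["example.org", 60000]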
+
+ def simple_upsert_many_txn(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ if (
+ self.database_engine.can_native_upsert
+ and table not in self._unsafe_to_upsert_tables
+ ):
+ return self.simple_upsert_many_txn_native_upsert(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+ else:
+ return self.simple_upsert_many_txn_emulated(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+
+ def simple_upsert_many_txn_emulated(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, but without native UPSERT support or batching.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ # No value columns, therefore make a blank list so that the following
+ # zip() works correctly.
+ if not value_names:
+ value_values = [() for x in range(len(key_values))]
+
+ for keyv, valv in zip(key_values, value_values):
+ _keys = {x: y for x, y in zip(key_names, keyv)}
+ _vals = {x: y for x, y in zip(value_names, valv)}
+
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+ def simple_upsert_many_txn_native_upsert(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, using batching where possible.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ allnames = []
+ allnames.extend(key_names)
+ allnames.extend(value_names)
+
+ if not value_names:
+ # No value columns, therefore make a blank list so that the
+ # following zip() works correctly.
+ latter = "NOTHING"
+ value_values = [() for x in range(len(key_values))]
+ else:
+ latter = "UPDATE SET " + ", ".join(
+ k + "=EXCLUDED." + k for k in value_names
+ )
+
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+ table,
+ ", ".join(k for k in allnames),
+ ", ".join("?" for _ in allnames),
+ ", ".join(key_names),
+ latter,
+ )
+
+ args = []
+
+ for x, y in zip(key_values, value_values):
+ args.append(tuple(x) + tuple(y))
+
+ return txn.execute_batch(sql, args)
+
+ def simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
+ ):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning multiple columns from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcols : list of strings giving the names of the columns to return
+
+ allow_none : If true, return None instead of failing if the SELECT
+ statement returns no rows
+ """
+ return self.runInteraction(
+ desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+ )
+
+ def simple_select_one_onecol(
+ self,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=False,
+ desc="simple_select_one_onecol",
+ ):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcol : string giving the name of the column to return
+ allow_none : If true, return None instead of raising StoreError
+ when no row matches
+ """
+ return self.runInteraction(
+ desc,
+ self.simple_select_one_onecol_txn,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=allow_none,
+ )
+
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls, txn, table, keyvalues, retcol, allow_none=False
+ ):
+ ret = cls.simple_select_onecol_txn(
+ txn, table=table, keyvalues=keyvalues, retcol=retcol
+ )
+
+ if ret:
+ return ret[0]
+ else:
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ @staticmethod
+ def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
+
+ if keyvalues:
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ txn.execute(sql)
+
+ return [r[0] for r in txn]
+
+ def simple_select_onecol(
+ self, table, keyvalues, retcol, desc="simple_select_onecol"
+ ):
+ """Executes a SELECT query on the named table, which returns a list
+ comprising the values of the named column from the selected rows.
+
+ Args:
+ table (str): table name
+ keyvalues (dict|None): column names and values to select the rows with
+ retcol (str): column whose value we wish to retrieve.
+
+ Returns:
+ Deferred: Results in a list
+ """
+ return self.runInteraction(
+ desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+ )
+
+ def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc, self.simple_select_list_txn, table, keyvalues, retcols
+ )
+
+ @classmethod
+ def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ """
+ if keyvalues:
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+ txn.execute(sql)
+
+ return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def simple_select_many_batch(
+ self,
+ table,
+ column,
+ iterable,
+ retcols,
+ keyvalues={},
+ desc="simple_select_many_batch",
+ batch_size=100,
+ ):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ results = []
+
+ if not iterable:
+ return results
+
+ # iterables cannot be sliced, so convert it to a list first
+ it_list = list(iterable)
+
+ chunks = [
+ it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
+ ]
+ for chunk in chunks:
+ rows = yield self.runInteraction(
+ desc,
+ self.simple_select_many_txn,
+ table,
+ column,
+ chunk,
+ keyvalues,
+ retcols,
+ )
+
+ results.extend(rows)
+
+ return results
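
For example (IDs invented), selecting against 250 user IDs with the default batch_size of 100 issues three interactions, of 100, 100 and 50 values, and concatenates the rows:

    user_ids = ["@u%d:hs" % i for i in range(250)]
    rows = yield self.db.simple_select_many_batch(
        table="erased_users",
        column="user_id",
        iterable=user_ids,
        retcols=("user_id",),
        desc="are_users_erased",
    )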
+
+ @classmethod
+ def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ if not iterable:
+ return []
+
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
+
+ for key, value in iteritems(keyvalues):
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join(clauses),
+ )
+
+ txn.execute(sql, values)
+ return cls.cursor_to_dict(txn)
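
`make_in_list_sql_clause` itself is not shown in this hunk (it is re-exported through _base.py at the top of this diff); roughly, it returns a clause/values pair whose shape depends on the engine. A hedged sketch of the expected result:

    clause, values = make_in_list_sql_clause(txn.database_engine, "user_id", ["a", "b"])
    # SQLite-style engines:  clause == "user_id IN (?,?)"   values == ["a", "b"]
    # PostgreSQL:            clause == "user_id = ANY(?)"   values == [["a", "b"]]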
+
+ def simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc, self.simple_update_txn, table, keyvalues, updatevalues
+ )
+
+ @staticmethod
+ def simple_update_txn(txn, table, keyvalues, updatevalues):
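+        """Executes an UPDATE query on the named table within an existing
+        transaction.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the rows
+                with; an empty dict updates all rows
+            updatevalues : dict giving column names and values to update
+
+        Returns:
+            int: the number of rows updated
+        """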
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
+
+ return txn.rowcount
+
+ def simple_update_one(
+ self, table, keyvalues, updatevalues, desc="simple_update_one"
+ ):
+ """Executes an UPDATE query on the named table, setting new values for
+ columns in a row matching the key values.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ updatevalues : dict giving column names and values to update
+            desc : description of the transaction, for logging and metrics
+
+        The update is performed within a single transaction, so it can be used
+        to implement compare-and-set by putting the column being updated in
+        the 'keyvalues' dict as well.
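+
+        Example (illustrative; a hypothetical `jobs` table):
+
+            yield self.simple_update_one(
+                "jobs",
+                keyvalues={"id": job_id, "state": "pending"},
+                updatevalues={"state": "running"},
+            )
+            # succeeds only if the row is still in the "pending" state;
+            # otherwise raises StoreError(404)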
+ """
+ return self.runInteraction(
+ desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+ )
+
+ @classmethod
+ def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
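+        """Executes an UPDATE query on the named table, expecting to update a
+        single row.
+
+        Raises:
+            StoreError(404): if no row matched the key values
+            StoreError(500): if more than one row matched
+        """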
+ rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
+
+ if rowcount == 0:
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ @staticmethod
+ def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
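+        """Executes a SELECT query on the named table, expecting to return a
+        single row.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcols : list of strings giving the names of the columns to return
+            allow_none (bool): if True, return None instead of raising
+                StoreError when no row matches
+
+        Returns:
+            dict[str, Any] | None
+        """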
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(select_sql, list(keyvalues.values()))
+ row = txn.fetchone()
+
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ return dict(zip(retcols, row))
+
+ def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+
+ @staticmethod
+ def simple_delete_one_txn(txn, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+ """
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(sql, list(keyvalues.values()))
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ def simple_delete(self, table, keyvalues, desc):
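+        """Executes a DELETE query on the named table, deleting any rows
+        matching the key values.
+
+        Returns:
+            defer.Deferred: resolves to the number of rows deleted
+        """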
+ return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+
+ @staticmethod
+ def simple_delete_txn(txn, table, keyvalues):
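+        """Executes a DELETE query on the named table within an existing
+        transaction.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the rows with
+
+        Returns:
+            int: the number of rows deleted
+        """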
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(sql, list(keyvalues.values()))
+ return txn.rowcount
+
+ def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ return self.runInteraction(
+ desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+ )
+
+ @staticmethod
+ def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+
+ Returns:
+            int: Number of rows deleted
+ """
+ if not iterable:
+ return 0
+
+ sql = "DELETE FROM %s" % table
+
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
+
+ for key, value in iteritems(keyvalues):
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+ txn.execute(sql, values)
+
+ return txn.rowcount
+
+ def get_cache_dict(
+ self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+ ):
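+        """Gets a mapping of entity id -> max stream position for entities
+        whose stream position is greater than `max_value - limit`, together
+        with the minimum stream position in that mapping (or `max_value` if
+        the mapping is empty). Used to prime a StreamChangeCache on startup.
+
+        Args:
+            db_conn : database connection
+            table (str): the table to query
+            entity_column (str): column holding the entity id (e.g. a room_id)
+            stream_column (str): column holding the stream ordering
+            max_value (int): the current maximum stream ordering
+            limit (int): how far below `max_value` to look back
+
+        Returns:
+            tuple[dict, int]
+        """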
+        # Fetch a mapping of entity (e.g. room_id) -> max stream position for
+        # "recent" entities. It doesn't really matter how many we get: the
+        # StreamChangeCache will do the right thing to ensure it respects its
+        # maximum size.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - %(limit)s"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ "limit": limit,
+ }
+
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+
+ cache = {row[0]: int(row[1]) for row in txn}
+
+ txn.close()
+
+ if cache:
+ min_val = min(itervalues(cache))
+ else:
+ min_val = max_value
+
+ return cache, min_val
+
+ def simple_select_list_paginate(
+ self,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction="ASC",
+ desc="simple_select_list_paginate",
+ ):
+ """
+        Executes a SELECT query on the named table, ordered by `orderby` and
+        restricted by `start` and `limit`, returning up to `limit` rows
+        starting at row number `start` as a list of dicts.
+
+ Args:
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
+ retcols (iterable[str]): the names of the columns to return
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            desc (str): description of the transaction, for logging and metrics
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc,
+ self.simple_select_list_paginate_txn,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction=order_direction,
+ )
+
+ @classmethod
+ def simple_select_list_paginate_txn(
+ cls,
+ txn,
+ table,
+ keyvalues,
+ orderby,
+ start,
+ limit,
+ retcols,
+ order_direction="ASC",
+ ):
+ """
+        Executes a SELECT query on the named table, ordered by `orderby` and
+        restricted by `start` and `limit`, returning up to `limit` rows
+        starting at row number `start` as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
+ retcols (iterable[str]): the names of the columns to return
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            list[dict[str, Any]]
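+
+        Example (illustrative; the table and column names here are
+        hypothetical):
+
+            simple_select_list_paginate_txn(
+                txn, "rooms", None, "name", 10, 5, ["room_id"]
+            )
+            # executes: SELECT room_id FROM rooms ORDER BY name ASC
+            #           LIMIT ? OFFSET ?  with args [5, 10]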
+ """
+ if order_direction not in ["ASC", "DESC"]:
+ raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
+        if keyvalues:
+            where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
+            arg_list = list(keyvalues.values())
+        else:
+            # keyvalues may legitimately be None/empty, meaning no WHERE clause
+            where_clause = ""
+            arg_list = []
+
+        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+            ", ".join(retcols),
+            table,
+            where_clause,
+            orderby,
+            order_direction,
+        )
+        txn.execute(sql, arg_list + [limit, start])
+
+ return cls.cursor_to_dict(txn)
+
+ def get_user_count_txn(self, txn):
+ """Get a total number of registered users in the users list.
+
+ Args:
+ txn : Transaction object
+ Returns:
+ int : number of users
+ """
+ sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
+ txn.execute(sql_count)
+ return txn.fetchone()[0]
+
+ def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+            term (str | None):
+                term to search for, matched against `col` with a LIKE query;
+                if None or empty, no search is performed.
+            col (str): the column to match `term` against
+            retcols (iterable[str]): the names of the columns to return
+            desc (str): description of the transaction, for logging and metrics
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]] (an empty list if
+            `term` is None or empty)
+ """
+
+ return self.runInteraction(
+ desc, self.simple_search_list_txn, table, term, col, retcols
+ )
+
+ @classmethod
+ def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+            term (str | None):
+                term to search for, matched against `col` with a LIKE query;
+                if None or empty, no search is performed.
+            col (str): the column to match `term` against
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            list[dict[str, Any]]: an empty list if `term` is None or empty
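+
+        Example (illustrative; the table and column names here are
+        hypothetical):
+
+            simple_search_list_txn(txn, "users", "alice", "name", ["name"])
+            # executes: SELECT name FROM users WHERE name LIKE ?
+            # with args ["%alice%"]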
+ """
+        if term:
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
+            termvalues = ["%" + term + "%"]
+            txn.execute(sql, termvalues)
+        else:
+            return []
+
+ return cls.cursor_to_dict(txn)
+
+
+def make_in_list_sql_clause(
+ database_engine, column: str, iterable: Iterable
+) -> Tuple[str, Iterable]:
+ """Returns an SQL clause that checks the given column is in the iterable.
+
+ On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+ it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+ using the `ANY` form on postgres means that it views queries with
+ different length iterables as the same, helping the query stats.
+
+ Args:
+ database_engine
+ column: Name of the column
+ iterable: The values to check the column against.
+
+ Returns:
+ A tuple of SQL query and the args
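+
+    Example (illustrative):
+
+        clause, args = make_in_list_sql_clause(engine, "user_id", ["@a:hs", "@b:hs"])
+        # sqlite:   clause == "user_id IN (?,?)",  args == ["@a:hs", "@b:hs"]
+        # postgres: clause == "user_id = ANY(?)",  args == [["@a:hs", "@b:hs"]]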
+ """
+
+ if database_engine.supports_using_any_list:
+ # This should hopefully be faster, but also makes postgres query
+ # stats easier to understand.
+ return "%s = ANY(?)" % (column,), [list(iterable)]
+ else:
+ return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
|