From 756d4942f5707922f29fe1fdfd945d73a19d7ac3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 13:52:46 +0000 Subject: Move DB pool and helper functions into dedicated Database class --- synapse/storage/database.py | 1485 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1485 insertions(+) create mode 100644 synapse/storage/database.py (limited to 'synapse/storage/database.py') diff --git a/synapse/storage/database.py b/synapse/storage/database.py new file mode 100644 index 0000000000..c2e121a001 --- /dev/null +++ b/synapse/storage/database.py @@ -0,0 +1,1485 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +import sys +import time +from typing import Iterable, Tuple + +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range + +from prometheus_client import Histogram + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.util.stringutils import exception_to_unicode + +# import a function which will return a monotonic time, in seconds +try: + # on python 3, use time.monotonic, since time.clock can go backwards + from time import monotonic as monotonic_time +except ImportError: + # ... but python 2 doesn't have it + from time import clock as monotonic_time + +logger = logging.getLogger(__name__) + +try: + MAX_TXN_ID = sys.maxint - 1 +except AttributeError: + # python 3 does not have a maximum int value + MAX_TXN_ID = 2 ** 63 - 1 + +sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") +perf_logger = logging.getLogger("synapse.storage.TIME") + +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") + +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) + + +# Unique indexes which have been added in background updates. Maps from table name +# to the name of the background update which added the unique index to that table. +# +# This is used by the upsert logic to figure out which tables are safe to do a proper +# UPSERT on: until the relevant background update has completed, we +# have to emulate an upsert by locking the table. +# +UNIQUE_INDEX_BACKGROUND_UPDATES = { + "user_ips": "user_ips_device_unique_index", + "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", + "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", + "event_search": "event_search_event_id_idx", +} + + +class LoggingTransaction(object): + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging and metrics to the .execute() + method. + + Args: + txn: The database transcation object to wrap. + name (str): The name of this transactions for logging. + database_engine (Sqlite3Engine|PostgresEngine) + after_callbacks(list|None): A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks(list|None): A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ + + __slots__ = [ + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", + ] + + def __init__( + self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + ): + object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) + object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) + + def call_after(self, callback, *args, **kwargs): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) + + def __getattr__(self, name): + return getattr(self.txn, name) + + def __setattr__(self, name, value): + setattr(self.txn, name, value) + + def __iter__(self): + return self.txn.__iter__() + + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch + + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + + def execute(self, sql, *args): + self._do_execute(self.txn.execute, sql, *args) + + def executemany(self, sql, *args): + self._do_execute(self.txn.executemany, sql, *args) + + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + + # TODO(paul): Maybe use 'info' and 'debug' for values? + sql_logger.debug("[SQL] {%s} %s", self.name, sql) + + sql = self.database_engine.convert_param_style(sql) + if args: + try: + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + except Exception: + # Don't let logging failures stop SQL from working + pass + + start = time.time() + + try: + return func(sql, *args) + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise + finally: + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) + + +class PerformanceCounters(object): + def __init__(self): + self.current_counters = {} + self.previous_counters = {} + + def update(self, key, duration_secs): + count, cum_time = self.current_counters.get(key, (0, 0)) + count += 1 + cum_time += duration_secs + self.current_counters[key] = (count, cum_time) + + def interval(self, interval_duration_secs, limit=3): + counters = [] + for name, (count, cum_time) in iteritems(self.current_counters): + prev_count, prev_time = self.previous_counters.get(name, (0, 0)) + counters.append( + ( + (cum_time - prev_time) / interval_duration_secs, + count - prev_count, + name, + ) + ) + + self.previous_counters = dict(self.current_counters) + + counters.sort(reverse=True) + + top_n_counters = ", ".join( + "%s(%d): %.3f%%" % (name, count, 100 * ratio) + for ratio, count, name in counters[:limit] + ) + + return top_n_counters + + +class Database(object): + _TXN_ID = 0 + + def __init__(self, hs): + self.hs = hs + self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() + + self._previous_txn_total_time = 0 + self._current_txn_total_time = 0 + self._previous_loop_ts = 0 + + # TODO(paul): These can eventually be removed once the metrics code + # is running in mainline, and we have some nice monitoring frontends + # to watch it + self._txn_perf_counters = PerformanceCounters() + + self.database_engine = hs.database_engine + + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.database_engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + + if self.database_engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + self.rand = random.SystemRandom() + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait 15 sec and check again. + """ + updates = yield self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): + if update_name not in updates: + logger.debug("Now safe to upsert in %s", table) + self._unsafe_to_upsert_tables.discard(table) + + # If there's any updates still running, reschedule to run. + if updates: + self._clock.call_later( + 15.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def start_profiling(self): + self._previous_loop_ts = monotonic_time() + + def loop(): + curr = self._current_txn_total_time + prev = self._previous_txn_total_time + self._previous_txn_total_time = curr + + time_now = monotonic_time() + time_then = self._previous_loop_ts + self._previous_loop_ts = time_now + + duration = time_now - time_then + ratio = (curr - prev) / duration + + top_three_counters = self._txn_perf_counters.interval(duration, limit=3) + + perf_logger.info( + "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters + ) + + self._clock.looping_call(loop, 10000) + + def new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): + start = monotonic_time() + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) + + name = "%s-%x" % (desc, txn_id) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.database_engine, + after_callbacks, + exception_callbacks, + ) + try: + r = func(cursor, *args, **kwargs) + conn.commit() + return r + except self.database_engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warning( + "[TXN OPERROR] {%s} %s %d/%d", + name, + exception_to_unicode(e), + i, + N, + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) + ) + continue + raise + except self.database_engine.module.DatabaseError as e: + if self.database_engine.is_deadlock(e): + logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.database_engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", + name, + exception_to_unicode(e1), + ) + continue + raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = monotonic_time() + duration = end - start + + LoggingContext.current_context().add_database_transaction(duration) + + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, duration) + sql_txn_timer.labels(desc).observe(duration) + + @defer.inlineCallbacks + def runInteraction(self, desc, func, *args, **kwargs): + """Starts a transaction on the database and runs a given function + + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + after_callbacks = [] + exception_callbacks = [] + + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warning("Starting db txn '%s' from sentinel context", desc) + + try: + result = yield self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) + + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + return result + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warning( + "Starting db connection from sentinel context: metrics will be lost" + ) + parent_context = None + + start_time = monotonic_time() + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = monotonic_time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) + + if self.database_engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + return func(conn, *args, **kwargs) + + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) + + return result + + @staticmethod + def cursor_to_dict(cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = list(intern(str(column[0])) for column in cursor.description) + results = list(dict(zip(col_headers, row)) for row in cursor) + return results + + def execute(self, desc, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + + def interaction(txn): + txn.execute(query, args) + if decoder: + return decoder(txn) + else: + return txn.fetchall() + + return self.runInteraction(desc, interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + @defer.inlineCallbacks + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True + """ + try: + yield self.runInteraction(desc, self.simple_insert_txn, table, values) + except self.database_engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise + return False + return True + + @staticmethod + def simple_insert_txn(txn, table, values): + keys, vals = zip(*values.items()) + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute(sql, vals) + + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + + @staticmethod + def simple_insert_many_txn(txn, table, values): + if not values: + return + + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) + + for k in keys: + if k != keys[0]: + raise RuntimeError("All items must have the same keys") + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.executemany(sql, vals) + + @defer.inlineCallbacks + def simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="simple_upsert", + lock=True, + ): + """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self.simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, + ) + return result + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warning( + "IntegrityError when upserting into %s; retrying: %s", table, e + ) + + def simple_upsert_txn( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self.simple_upsert_txn_native_upsert( + txn, table, keyvalues, values, insertion_values=insertion_values + ) + else: + return self.simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + bool: Return True if a new entry was created, False if an existing + one was updated. + """ + # We need to lock the table :(, unless we're *really* careful + if lock: + self.database_engine.lock_table(txn, table) + + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(values.values()) + list(keyvalues.values()) + + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False + + # We didn't find any existing rows, so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True + + def simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(insertion_values) + + if not values: + latter = "NOTHING" + else: + allvalues.update(values) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) + + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + latter, + ) + txn.execute(sql, list(allvalues.values())) + + def simple_upsert_many_txn( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + if ( + self.database_engine.can_native_upsert + and table not in self._unsafe_to_upsert_tables + ): + return self.simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self.simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def simple_upsert_many_txn_emulated( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def simple_upsert_many_txn_native_upsert( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, using batching where possible. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + allnames = [] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning multiple columns from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self.runInteraction( + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + ) + + def simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="simple_select_one_onecol", + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + return self.runInteraction( + desc, + self.simple_select_one_onecol_txn, + table, + keyvalues, + retcol, + allow_none=allow_none, + ) + + @classmethod + def simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): + ret = cls.simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, retcol=retcol + ) + + if ret: + return ret[0] + else: + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + @staticmethod + def simple_select_onecol_txn(txn, table, keyvalues, retcol): + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} + + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) + + return [r[0] for r in txn] + + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" + ): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict|None): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + return self.runInteraction( + desc, self.simple_select_onecol_txn, table, keyvalues, retcol + ) + + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, self.simple_select_list_txn, table, keyvalues, retcols + ) + + @classmethod + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + txn.execute(sql, list(keyvalues.values())) + else: + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + txn.execute(sql) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="simple_select_many_batch", + batch_size=100, + ): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] + + if not iterable: + return results + + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + + chunks = [ + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) + ] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self.simple_select_many_txn, + table, + column, + chunk, + keyvalues, + retcols, + ) + + results.extend(rows) + + return results + + @classmethod + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + if not iterable: + return [] + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return cls.cursor_to_dict(txn) + + def simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, self.simple_update_txn, table, keyvalues, updatevalues + ) + + @staticmethod + def simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) + + return txn.rowcount + + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" + ): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. + + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self.runInteraction( + desc, self.simple_update_one_txn, table, keyvalues, updatevalues + ) + + @classmethod + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) + + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + @staticmethod + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(select_sql, list(keyvalues.values())) + row = txn.fetchone() + + if not row: + if allow_none: + return None + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + return dict(zip(retcols, row)) + + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + + @staticmethod + def simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + if txn.rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + + @staticmethod + def simple_delete_txn(txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount + + def simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted + """ + if not iterable: + return 0 + + sql = "DELETE FROM %s" % table + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + txn.execute(sql, values) + + return txn.rowcount + + def get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - %(limit)s" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + "limit": limit, + } + + sql = self.database_engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + + cache = {row[0]: int(row[1]) for row in txn} + + txn.close() + + if cache: + min_val = min(itervalues(cache)) + else: + min_val = max_value + + return cache, min_val + + def simple_select_list_paginate( + self, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + desc="simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self.simple_select_list_paginate_txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction=order_direction, + ) + + @classmethod + def simple_select_list_paginate_txn( + cls, + txn, + table, + keyvalues, + orderby, + start, + limit, + retcols, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + + if keyvalues: + where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) + else: + where_clause = "" + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, list(keyvalues.values()) + [limit, start]) + + return cls.cursor_to_dict(txn) + + def get_user_count_txn(self, txn): + """Get a total number of registered users in the users list. + + Args: + txn : Transaction object + Returns: + int : number of users + """ + sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" + txn.execute(sql_count) + return txn.fetchone()[0] + + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, self.simple_search_list_txn, table, term, col, retcols + ) + + @classmethod + def simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) -- cgit 1.4.1 From 8863624f7852ffc4a261aa9d17f6f7ddb5bf0c19 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 14:00:29 +0000 Subject: Comments --- synapse/storage/__init__.py | 8 ++++---- synapse/storage/_base.py | 8 +++++++- synapse/storage/database.py | 5 +++++ 3 files changed, 16 insertions(+), 5 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0460fe8cc9..8fb18203dc 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -17,10 +17,10 @@ """ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple -databases). The `data_stores` are classes that talk directly to a single -database and have associated schemas, background updates, etc. On top of those -there are (or will be) classes that provide high level interfaces that combine -calls to multiple `data_stores`. +databases). The `Database` class represents a single physical database. The +`data_stores` are classes that talk directly to a `Database` instance and have +associated schemas, background updates, etc. On top of those there are classes +that provide high level interfaces that combine calls to multiple `data_stores`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index fd5bb3e1de..b7e27d4e97 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -31,11 +31,17 @@ logger = logging.getLogger(__name__) class SQLBaseStore(object): + """Base class for data stores that holds helper functions. + + Note that multiple instances of this class will exist as there will be one + per data store (and not one per physical database). + """ + def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine - self.db = Database(hs) + self.db = Database(hs) # In future this will be passed in self.rand = random.SystemRandom() def _invalidate_state_caches(self, room_id, members_changed): diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c2e121a001..ac64d80806 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -211,6 +211,11 @@ class PerformanceCounters(object): class Database(object): + """Wraps a single physical database and connection pool. + + A single database may be used by multiple data stores. + """ + _TXN_ID = 0 def __init__(self, hs): -- cgit 1.4.1 From 4a33a6dd19590b8e6626a5af5a69507dc11236f8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Dec 2019 15:09:36 +0000 Subject: Move background update handling out of store --- synapse/app/homeserver.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/storage/background_updates.py | 15 +++---- synapse/storage/data_stores/main/client_ips.py | 36 ++++++++-------- synapse/storage/data_stores/main/deviceinbox.py | 9 ++-- synapse/storage/data_stores/main/devices.py | 15 ++++--- .../storage/data_stores/main/event_federation.py | 6 +-- .../storage/data_stores/main/event_push_actions.py | 4 +- synapse/storage/data_stores/main/events.py | 6 +-- .../storage/data_stores/main/events_bg_updates.py | 49 +++++++++++---------- .../storage/data_stores/main/media_repository.py | 6 +-- synapse/storage/data_stores/main/registration.py | 21 ++++----- synapse/storage/data_stores/main/room.py | 9 ++-- synapse/storage/data_stores/main/roommember.py | 27 +++++++----- synapse/storage/data_stores/main/search.py | 31 ++++++++------ synapse/storage/data_stores/main/state.py | 19 ++++---- synapse/storage/data_stores/main/stats.py | 20 ++++----- synapse/storage/data_stores/main/user_directory.py | 33 ++++++++------ synapse/storage/database.py | 3 ++ synmark/__init__.py | 6 +-- tests/handlers/test_stats.py | 50 +++++++++++++++------- tests/handlers/test_user_directory.py | 18 +++++--- tests/storage/test_background_update.py | 26 +++++++---- tests/storage/test_cleanup_extrems.py | 14 ++++-- tests/storage/test_client_ips.py | 26 ++++++++--- tests/storage/test_roommember.py | 18 +++++--- tests/unittest.py | 10 +++-- 27 files changed, 281 insertions(+), 200 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 267aebaae9..9f81a857ab 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -436,7 +436,7 @@ def setup(config_options): _base.start(hs, config.listeners) hs.get_pusherpool().start() - hs.get_datastore().start_doing_background_updates() + hs.get_datastore().db.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb0d02aa83..6b978be876 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource): logger.info("Running url preview cache expiry") - if not (yield self.store.has_completed_background_updates()): + if not (yield self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index dfca94b0e0..a9a13a2658 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from . import engines -from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdateStore(SQLBaseStore): +class BackgroundUpdater(object): """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, db_conn, hs): - super(BackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, hs, database): + self._clock = hs.get_clock() + self.db = database + self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} @@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info("Starting background schema updates") while True: if sleep: - yield self.hs.get_clock().sleep( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 - ) + yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: result = yield self.do_next_background_update( @@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.database_engine, engines.PostgresEngine): + if isinstance(self.db.database_engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 6f2a720b97..7b470a58f1 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -20,7 +20,7 @@ from six import iteritems from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -32,41 +32,41 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 -class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): +class ClientIpBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -75,12 +75,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) # Drop the old non-unique index - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) @@ -92,7 +92,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.close() yield self.db.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") + yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") return 1 @defer.inlineCallbacks @@ -108,7 +108,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - yield self._end_background_update("user_ips_analyze") + yield self.db.updates._end_background_update("user_ips_analyze") return 1 @@ -271,14 +271,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) yield self.db.runInteraction("user_ips_dups_remove", remove) if last: - yield self._end_background_update("user_ips_remove_dupes") + yield self.db.updates._end_background_update("user_ips_remove_dupes") return batch_size @@ -344,7 +344,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -357,7 +357,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) if not updated: - yield self._end_background_update("devices_last_seen") + yield self.db.updates._end_background_update("devices_last_seen") return updated @@ -546,7 +546,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Nothing to do return - if not await self.has_completed_background_update("devices_last_seen"): + if not await self.db.updates.has_completed_background_update( + "devices_last_seen" + ): # Only start pruning if we have finished populating the devices # last seen info. return diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 440793ad49..3c9f09301a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,6 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -208,20 +207,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) -class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, db_conn, hs): super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) @@ -234,7 +233,7 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index d98511ddd4..91ddaf137e 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -31,7 +31,6 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter from synapse.util.caches.descriptors import ( @@ -642,11 +641,11 @@ class DeviceWorkerStore(SQLBaseStore): return results -class DeviceBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -654,7 +653,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # create a unique index on device_lists_remote_cache - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -663,7 +662,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # And one on device_lists_remote_extremeties - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -672,7 +671,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) @@ -686,7 +685,9 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): txn.close() yield self.db.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + yield self.db.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) return 1 diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 77e4353b59..31d2e8eb28 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -494,7 +494,7 @@ class EventFederationStore(EventFederationWorkerStore): def __init__(self, db_conn, hs): super(EventFederationStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -654,7 +654,7 @@ class EventFederationStore(EventFederationWorkerStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) @@ -665,6 +665,6 @@ class EventFederationStore(EventFederationWorkerStore): ) if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 725d0881dc..eec054cd48 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -614,14 +614,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def __init__(self, db_conn, hs): super(EventPushActionsStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 01ec9ec397..d644c82784 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -38,7 +38,6 @@ from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore @@ -94,10 +93,7 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, + StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): def __init__(self, db_conn, hs): super(EventsStore, self).__init__(db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 365e966956..cb1fc30c31 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -22,13 +22,12 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause logger = logging.getLogger(__name__) -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): +class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" @@ -37,15 +36,15 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -56,7 +55,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -65,16 +64,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): psql_only=True, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -82,11 +81,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): where_clause="have_censored", ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) @@ -145,7 +144,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) @@ -156,7 +155,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) return result @@ -222,7 +223,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows_to_update), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) @@ -233,7 +234,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + yield self.db.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) return result @@ -411,7 +414,9 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + yield self.db.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") @@ -464,7 +469,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) @@ -475,7 +480,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not count: - yield self._end_background_update("redactions_received_ts") + yield self.db.updates._end_background_update("redactions_received_ts") return count @@ -505,7 +510,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self._end_background_update("event_fix_redactions_bytes") + yield self.db.updates._end_background_update("event_fix_redactions_bytes") return 1 @@ -559,7 +564,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): nbrows += 1 last_row_event_id = event_id - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) @@ -570,6 +575,6 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): ) if not num_rows: - yield self._end_background_update("event_store_labels") + yield self.db.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index ea02497784..03c9c6f8ae 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore -class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 8f9aa87ceb..1ef143c6d8 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -26,7 +26,6 @@ from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -794,23 +793,21 @@ class RegistrationWorkerStore(SQLBaseStore): ) -class RegistrationBackgroundUpdateStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, db_conn, hs): super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.register_background_index_update( + self.db.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -820,13 +817,13 @@ class RegistrationBackgroundUpdateStore( # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.register_noop_background_update("refresh_tokens_device_index") + self.db.updates.register_noop_background_update("refresh_tokens_device_index") - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) @@ -873,7 +870,7 @@ class RegistrationBackgroundUpdateStore( logger.info("Marked %d rows as deactivated", rows_processed_nb) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -887,7 +884,7 @@ class RegistrationBackgroundUpdateStore( ) if end: - yield self._end_background_update("users_set_deactivated_flag") + yield self.db.updates._end_background_update("users_set_deactivated_flag") return nb_processed @@ -917,7 +914,7 @@ class RegistrationBackgroundUpdateStore( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self._end_background_update("user_threepids_grandfather") + yield self.db.updates._end_background_update("user_threepids_grandfather") return 1 diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index a26ed47afc..da42dae243 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -28,7 +28,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchStore from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -361,13 +360,13 @@ class RoomWorkerStore(SQLBaseStore): defer.returnValue(row) -class RoomBackgroundUpdateStore(BackgroundUpdateStore): +class RoomBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) self.config = hs.config - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "insert_room_retention", self._background_insert_retention, ) @@ -421,7 +420,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): logger.info("Inserted %d rows into room_retention", len(rows)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} ) @@ -435,7 +434,7 @@ class RoomBackgroundUpdateStore(BackgroundUpdateStore): ) if end: - yield self._end_background_update("insert_room_retention") + yield self.db.updates._end_background_update("insert_room_retention") defer.returnValue(batch_size) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 7f4d02b25b..929f6b0d39 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -26,8 +26,11 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + make_in_list_sql_clause, +) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( @@ -831,17 +834,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): +class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(self, db_conn, hs): super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -909,7 +912,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) @@ -920,7 +923,9 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + yield self.db.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) return result @@ -959,7 +964,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): last_processed_room = next_room - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -978,7 +983,9 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): ) if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + yield self.db.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) return row_count diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 55a604850e..ffa1817e64 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -24,8 +24,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -36,7 +35,7 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(BackgroundUpdateStore): +class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -49,10 +48,10 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): if not hs.config.enable_search: return - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -61,9 +60,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + self.db.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @@ -153,7 +154,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(event_search_rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) @@ -164,7 +165,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) return result @@ -208,7 +209,9 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): yield self.db.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) return 1 @defer.inlineCallbacks @@ -244,7 +247,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -274,7 +277,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "have_added_indexes": True, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) @@ -285,7 +288,9 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME + ) return num_rows diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 851e81d6b3..7d5a9f8128 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -27,7 +27,6 @@ from synapse.api.errors import NotFoundError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter @@ -1023,9 +1022,7 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" @@ -1034,21 +1031,21 @@ class StateBackgroundUpdateStore( def __init__(self, db_conn, hs): super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", @@ -1181,7 +1178,7 @@ class StateBackgroundUpdateStore( "max_group": max_group, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) @@ -1192,7 +1189,7 @@ class StateBackgroundUpdateStore( ) if finished: - yield self._end_background_update( + yield self.db.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) @@ -1224,7 +1221,7 @@ class StateBackgroundUpdateStore( yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) return 1 diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 974ffc15bd..6b91988c2a 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -68,17 +68,17 @@ class StatsStore(StateDeltasStore): self.stats_delta_processing_lock = DeferredLock() - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") + self.db.updates.register_noop_background_update("populate_stats_cleanup") + self.db.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -102,7 +102,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -123,7 +123,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: @@ -132,7 +132,7 @@ class StatsStore(StateDeltasStore): yield self.db.runInteraction( "populate_stats_process_users", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -145,7 +145,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -166,7 +166,7 @@ class StatsStore(StateDeltasStore): # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: @@ -175,7 +175,7 @@ class StatsStore(StateDeltasStore): yield self.db.runInteraction( "_populate_stats_process_rooms", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", progress, ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 7118bd62f3..62ffb34b29 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -19,7 +19,6 @@ import re from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -32,7 +31,7 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? @@ -43,19 +42,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore self.server_name = hs.hostname - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) @@ -108,7 +107,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - yield self._end_background_update("populate_user_directory_createtables") + yield self.db.updates._end_background_update( + "populate_user_directory_createtables" + ) return 1 @defer.inlineCallbacks @@ -130,7 +131,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore "populate_user_directory_cleanup", _delete_staging_area ) - yield self._end_background_update("populate_user_directory_cleanup") + yield self.db.updates._end_background_update("populate_user_directory_cleanup") return 1 @defer.inlineCallbacks @@ -176,7 +177,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") + yield self.db.updates._end_background_update( + "populate_user_directory_process_rooms" + ) return 1 logger.info( @@ -248,7 +251,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore progress["remaining"] -= 1 yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,7 +270,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 def _get_next_batch(txn): @@ -297,7 +302,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore # No more users -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 logger.info( @@ -317,7 +324,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore progress["remaining"] -= 1 yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ac64d80806..be36c1b829 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -30,6 +30,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.util.stringutils import exception_to_unicode @@ -223,6 +224,8 @@ class Database(object): self._clock = hs.get_clock() self._db_pool = hs.get_db_pool() + self.updates = BackgroundUpdater(hs, self) + self._previous_txn_total_time = 0 self._current_txn_total_time = 0 self._previous_loop_ts = 0 diff --git a/synmark/__init__.py b/synmark/__init__.py index 570eb818d9..afe4fad8cb 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None): stor = hs.get_datastore() # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not await stor.has_completed_background_updates(): - await stor.do_next_background_update(1) + if hasattr(stor.db.updates, "do_next_background_update"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) def cleanup(): for i in cleanup_tasks: diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 7f7962c3dd..d9d312f0fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -42,7 +42,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( @@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,7 +186,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_update_one( table="stats_incremental_position", @@ -194,8 +202,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -653,7 +669,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_delete( @@ -673,7 +689,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( "background_updates", @@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index bc9d441541..26071059d2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -181,7 +181,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( self.store.db.simple_insert( @@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() @@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index e360297df9..aec76f4ab1 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler = Mock() - yield self.store.register_background_update_handler( + yield self.store.db.updates.register_background_update_handler( "test_update", self.update_handler ) @@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): # (perhaps we should run them as part of the test HS setup, since we # run all of the other schema setup stuff there?) while True: - res = yield self.store.do_next_background_update(1000) + res = yield self.store.db.updates.do_next_background_update(1000) if res is None: break @@ -39,7 +39,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): progress = {"my_key": progress["my_key"] + 1} yield self.store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.store.db.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update - yield self.store.start_background_update("test_update", {"my_key": 1}) + yield self.store.db.updates.start_background_update( + "test_update", {"my_key": 1} + ) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + yield self.store.db.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index e454bbff29..029ac26454 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, @@ -68,10 +70,14 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index c4f838907c..fc279340d4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -202,8 +202,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" @@ -256,11 +260,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # We should now get the correct result again result = self.get_success( @@ -281,8 +289,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 5f957680a2..7840f63fe3 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -122,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -143,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/unittest.py b/tests/unittest.py index fc856a574a..68d245ec9f 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -401,10 +401,12 @@ class HomeserverTestCase(TestCase): hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() - # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not self.get_success(stor.has_completed_background_updates()): - self.get_success(stor.do_next_background_update(1)) + # Run the database background updates, when running against "master". + if hs.__class__.__name__ == "TestHomeServer": + while not self.get_success( + stor.db.updates.has_completed_background_updates() + ): + self.get_success(stor.db.updates.do_next_background_update(1)) return hs -- cgit 1.4.1 From e216ec381af713b9bc9d629bad219f4eb6a1a884 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 11:15:25 +0000 Subject: Remove unused var --- .buildkite/postgres-config.yaml | 2 +- .buildkite/sqlite-config.yaml | 2 +- synapse/storage/database.py | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml index a35fec394d..dcf72cfda0 100644 --- a/.buildkite/postgres-config.yaml +++ b/.buildkite/postgres-config.yaml @@ -1,7 +1,7 @@ # Configuration file used for testing the 'synapse_port_db' script. # Tells the script to connect to the postgresql database that will be available in the # CI's Docker setup at the point where this file is considered. -server_name: "test" +server_name: "localhost:8080" signing_key_path: "/src/.buildkite/test.signing.key" diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml index 635b921764..5276aaff03 100644 --- a/.buildkite/sqlite-config.yaml +++ b/.buildkite/sqlite-config.yaml @@ -1,7 +1,7 @@ # Configuration file used for testing the 'synapse_port_db' script. # Tells the 'update_database' script to connect to the test SQLite database to upgrade its # schema and run background updates on it. -server_name: "test" +server_name: "localhost:8080" signing_key_path: "/src/.buildkite/test.signing.key" diff --git a/synapse/storage/database.py b/synapse/storage/database.py index be36c1b829..bd515d70d2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -256,8 +256,6 @@ class Database(object): self._check_safe_to_upsert, ) - self.rand = random.SystemRandom() - @defer.inlineCallbacks def _check_safe_to_upsert(self): """ -- cgit 1.4.1 From d537be1ebd0e7ce4c84118efa400932cc6432aa9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 13:40:02 +0000 Subject: Pass Database into the data store --- synapse/server.py | 3 +-- synapse/storage/_base.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/data_stores/__init__.py | 7 ++++-- synapse/storage/database.py | 38 ++++++++++++++------------------- 5 files changed, 24 insertions(+), 28 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/synapse/server.py b/synapse/server.py index be9af7f986..2db3dab221 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -238,8 +238,7 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(datastore, conn, self) + self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) conn.commit() self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f9e7f9a71e..b7637b5dc0 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -41,7 +41,7 @@ class SQLBaseStore(object): self.hs = hs self._clock = hs.get_clock() self.database_engine = hs.database_engine - self.db = Database(hs) # In future this will be passed in + self.db = database self.rand = random.SystemRandom() def _invalidate_state_caches(self, room_id, members_changed): diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index a9a13a2658..4f97fd5ab6 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -379,7 +379,7 @@ class BackgroundUpdater(object): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.db.database_engine, engines.PostgresEngine): + if isinstance(self.db.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cb184a98cc..79ecc62735 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.database import Database + class DataStores(object): """The various data stores. @@ -20,7 +22,8 @@ class DataStores(object): These are low level interfaces to physical databases. """ - def __init__(self, main_store, db_conn, hs): + def __init__(self, main_store_class, db_conn, hs): # Note we pass in the main store here as workers use a different main # store. - self.main = main_store + database = Database(hs) + self.main = main_store_class(database, db_conn, hs) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6843b7e7f8..ec19ae1d9d 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -234,7 +234,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.database_engine = hs.database_engine + self.engine = hs.database_engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -242,10 +242,10 @@ class Database(object): # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): + if isinstance(self.engine, Sqlite3Engine): self._unsafe_to_upsert_tables.add("user_directory_search") - if self.database_engine.can_native_upsert: + if self.engine.can_native_upsert: # Check ASAP (and then later, every 1s) to see if we have finished # background updates of tables that aren't safe to update. self._clock.call_later( @@ -331,7 +331,7 @@ class Database(object): cursor = LoggingTransaction( conn.cursor(), name, - self.database_engine, + self.engine, after_callbacks, exception_callbacks, ) @@ -339,7 +339,7 @@ class Database(object): r = func(cursor, *args, **kwargs) conn.commit() return r - except self.database_engine.module.OperationalError as e: + except self.engine.module.OperationalError as e: # This can happen if the database disappears mid # transaction. logger.warning( @@ -353,20 +353,20 @@ class Database(object): i += 1 try: conn.rollback() - except self.database_engine.module.Error as e1: + except self.engine.module.Error as e1: logger.warning( "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) ) continue raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): + except self.engine.module.DatabaseError as e: + if self.engine.is_deadlock(e): logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) if i < N: i += 1 try: conn.rollback() - except self.database_engine.module.Error as e1: + except self.engine.module.Error as e1: logger.warning( "[TXN EROLL] {%s} %s", name, @@ -494,7 +494,7 @@ class Database(object): sql_scheduling_timer.observe(sched_duration_sec) context.add_database_scheduled(sched_duration_sec) - if self.database_engine.is_connection_closed(conn): + if self.engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") conn.reconnect() @@ -561,7 +561,7 @@ class Database(object): """ try: yield self.runInteraction(desc, self.simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: + except self.engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. if not or_ignore: @@ -660,7 +660,7 @@ class Database(object): lock=lock, ) return result - except self.database_engine.module.IntegrityError as e: + except self.engine.module.IntegrityError as e: attempts += 1 if attempts >= 5: # don't retry forever, because things other than races @@ -692,10 +692,7 @@ class Database(object): upserts return True if a new entry was created, False if an existing one was updated. """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) @@ -726,7 +723,7 @@ class Database(object): """ # We need to lock the table :(, unless we're *really* careful if lock: - self.database_engine.lock_table(txn, table) + self.engine.lock_table(txn, table) def _getwhere(key): # If the value we're passing in is None (aka NULL), we need to use @@ -828,10 +825,7 @@ class Database(object): Returns: None """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) @@ -1301,7 +1295,7 @@ class Database(object): "limit": limit, } - sql = self.database_engine.convert_param_style(sql) + sql = self.engine.convert_param_style(sql) txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) -- cgit 1.4.1 From 2284eb3a533a2df04784df08da28e67d6588a5ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Dec 2019 10:45:12 +0000 Subject: Add database config class (#6513) This encapsulates config for a given database and is the way to get new connections. --- changelog.d/6513.misc | 1 + scripts-dev/update_database | 9 +-- scripts/synapse_port_db | 58 ++++++++----------- synapse/config/database.py | 78 ++++++++++++++++++++------ synapse/handlers/presence.py | 2 +- synapse/server.py | 41 ++------------ synapse/storage/_base.py | 2 +- synapse/storage/data_stores/__init__.py | 40 ++++++++++--- synapse/storage/data_stores/main/client_ips.py | 2 +- synapse/storage/database.py | 45 ++++++++++++++- synapse/storage/engines/sqlite.py | 16 +++++- synapse/storage/prepare_database.py | 7 +-- tests/handlers/test_typing.py | 39 ++++++------- tests/replication/slave/storage/_base.py | 6 +- tests/server.py | 55 +++++++++--------- tests/storage/test_appservice.py | 37 ++++++++---- tests/storage/test_base.py | 14 +++-- tests/storage/test_registration.py | 1 - tests/utils.py | 43 +++++--------- 19 files changed, 287 insertions(+), 209 deletions(-) create mode 100644 changelog.d/6513.misc (limited to 'synapse/storage/database.py') diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc new file mode 100644 index 0000000000..36700f5657 --- /dev/null +++ b/changelog.d/6513.misc @@ -0,0 +1 @@ +Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`. diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 23017c21f8..1d62f0403a 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -77,12 +76,8 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver(config) - db_conn = hs.get_db_conn() - # Update the database to the latest schema. - prepare_database(db_conn, hs.database_engine, config=config) - db_conn.commit() - - # setup instantiates the store within the homeserver object. + # Setup instantiates the store within the homeserver object and updates the + # DB. hs.setup() store = hs.get_datastore() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e393a9b2f7..5b5368988c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -30,6 +30,7 @@ import yaml from twisted.enterprise import adbapi from twisted.internet import defer, reactor +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import PreserveLoggingContext from synapse.storage._base import LoggingTransaction @@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -165,23 +166,17 @@ class Store( class MockHomeserver: - def __init__(self, config, database_engine, db_conn, db_pool): - self.database_engine = database_engine - self.db_conn = db_conn - self.db_pool = db_pool + def __init__(self, config): self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - def get_db_conn(self): - return self.db_conn - - def get_db_pool(self): - return self.db_pool - def get_clock(self): return self.clock + def get_reactor(self): + return reactor + class Porter(object): def __init__(self, **kwargs): @@ -445,45 +440,36 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - - prepare_database(db_conn, database_engine, config=None) + def setup_db(self, db_config: DatabaseConnectionConfig, engine): + db_conn = make_conn(db_config, engine) + prepare_database(db_conn, engine, config=None) db_conn.commit() return db_conn @defer.inlineCallbacks - def build_db_store(self, config): + def build_db_store(self, db_config: DatabaseConnectionConfig): """Builds and returns a database store using the provided configuration. Args: - config: The database configuration, i.e. a dict following the structure of - the "database" section of Synapse's configuration file. + config: The database configuration Returns: The built Store object. """ - engine = create_engine(config) - - self.progress.set_state("Preparing %s" % config["name"]) - conn = self.setup_db(config, engine) + self.progress.set_state("Preparing %s" % db_config.config["name"]) - db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) + engine = create_engine(db_config.config) + conn = self.setup_db(db_config, engine) - hs = MockHomeserver(self.hs_config, engine, conn, db_pool) + hs = MockHomeserver(self.hs_config) - store = Store(Database(hs), conn, hs) + store = Store(Database(hs, db_config, engine), conn, hs) yield store.db.runInteraction( - "%s_engine.check_database" % config["name"], engine.check_database, + "%s_engine.check_database" % db_config.config["name"], + engine.check_database, ) return store @@ -509,7 +495,11 @@ class Porter(object): @defer.inlineCallbacks def run(self): try: - self.sqlite_store = yield self.build_db_store(self.sqlite_config) + self.sqlite_store = yield self.build_db_store( + DatabaseConnectionConfig( + "master", self.sqlite_config, data_stores=["main"] + ) + ) # Check if all background updates are done, abort if not. updates_complete = ( @@ -524,7 +514,7 @@ class Porter(object): defer.returnValue(None) self.postgres_store = yield self.build_db_store( - self.hs_config.database_config + self.hs_config.get_single_database() ) yield self.run_background_updates_on_postgres() diff --git a/synapse/config/database.py b/synapse/config/database.py index 0e2509f0b1..5f2f3c7cfd 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,12 +12,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from textwrap import indent +from typing import List import yaml -from ._base import Config +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. + + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has two fields: `name` for database + module name, and `args` for the args to give to the database + connector. + data_stores: The list of data stores that should be provisioned on the + database. + """ + + def __init__(self, name: str, db_config: dict, data_stores: List[str]): + if db_config["name"] not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + + if db_config["name"] == "sqlite3": + db_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) + + self.name = name + self.config = db_config + self.data_stores = data_stores class DatabaseConfig(Config): @@ -26,20 +57,14 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - self.database_config = config.get("database") + database_config = config.get("database") - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( - {"cp_min": 1, "cp_max": 1, "check_same_thread": False} - ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) + self.databases = [ + DatabaseConnectionConfig("master", database_config, data_stores=["main"]) + ] self.set_databasepath(config.get("database_path")) @@ -76,11 +101,24 @@ class DatabaseConfig(Config): self.set_databasepath(args.database_path) def set_databasepath(self, database_path): + if database_path is None: + return + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + # We only support setting a database path if we have a single sqlite3 + # database. + if len(self.databases) != 1: + raise ConfigError("Cannot specify 'database_path' with multiple databases") + + database = self.get_single_database() + if database.config["name"] != "sqlite3": + # We don't raise here as we haven't done so before for this case. + logger.warn("Ignoring 'database_path' for non-sqlite3 database") + return + + database.config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -91,3 +129,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if len(self.databases) != 1: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eda15bc623..240c4add12 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -230,7 +230,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.database.is_running(): return logger.info( diff --git a/synapse/server.py b/synapse/server.py index 5021068ce0..7926867b77 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,6 @@ import abc import logging import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS @@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage -from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -134,7 +132,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -233,12 +230,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.database_engine = create_engine(config.database_config) - config.database_config.setdefault("args", {})[ - "cp_openfun" - ] = self.database_engine.on_new_connection - self.db_config = config.database_config - self.datastores = None # Other kwargs are explicit dependencies @@ -247,10 +238,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) - conn.commit() self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -284,6 +273,9 @@ class HomeServer(object): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config @@ -433,31 +425,6 @@ class HomeServer(object): ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cafedd5c0d..0983e059c0 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,24 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + db_conn.commit() - prepare_database(db_conn, database.engine, config=hs.config) + self.databases.append(database) - self.main = main_store_class(database, db_conn, hs) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index add3037b69..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index ddad17dc5a..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -25,6 +23,9 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + database = database_config.get("args", {}).get("database") + self._is_in_memory = database in (None, ":memory:",) + # The current max state_group, or None if we haven't looked # in the DB yet. self._current_state_group_id = None @@ -59,7 +60,16 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + + if self._is_in_memory: + # In memory databases need to be rebuilt each time. Ideally we'd + # reuse the same connection as we do when starting up, but that + # would involve using adbapi before we have started the reactor. + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) def is_deadlock(self, error): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 731e1c9d9c..b4194b44ee 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 92b8726093..596ddc6970 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] + ) + hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_device_updates_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), - notifier=Mock(), - http_client=mock_federation_client, - keyring=mock_keyring, + notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring ) + hs.datastores = datastores + return hs def prepare(self, reactor, clock, hs): diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 3dae83c543..2a1e7c7166 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,7 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import Database +from synapse.storage.database import make_conn from tests import unittest from tests.server import FakeTransport @@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): + db_config = hs.config.database.get_single_database() self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() + database = hs.get_datastores().databases[0] self.slaved_store = self.STORE_TYPE( - Database(hs), self.hs.get_db_conn(), self.hs + database, make_conn(db_config, database.engine), self.hs ) self.event_id = 0 diff --git a/tests/server.py b/tests/server.py index 2b7cf4242e..a554dfdd57 100644 --- a/tests/server.py +++ b/tests/server.py @@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 2e521e9ab7..fd52512696 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = Database(hs) - self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - database = Database(hs) - self.store = TestTransactionStore(database, hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine + + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs + ) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 537cfe9f64..cdee0a9e60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(Database(hs), None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..ed5786865a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 585f305b9a..9f5bf40b4b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,7 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server @@ -177,7 +178,6 @@ class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore -@defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", @@ -214,7 +214,7 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - config.database_config = { + database_config = { "name": "psycopg2", "args": { "database": test_db, @@ -226,12 +226,15 @@ def setup_test_homeserver( }, } else: - config.database_config = { + database_config = { "name": "sqlite3", "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - db_engine = create_engine(config.database_config) + database = DatabaseConnectionConfig("master", database_config, ["main"]) + config.database.databases = [database] + + db_engine = create_engine(database.config) # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() @@ -251,11 +254,6 @@ def setup_test_homeserver( cur.close() db_conn.close() - # we need to configure the connection pool to run the on_new_connection - # function, so that we can test code that uses custom sqlite functions - # (like rank). - config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection - if datastore is None: hs = homeserverToUse( name, @@ -267,21 +265,19 @@ def setup_test_homeserver( **kargs ) - # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to - # date db - if not isinstance(db_engine, PostgresEngine): - db_conn = hs.get_db_conn() - yield prepare_database(db_conn, db_engine, config) - db_conn.commit() - db_conn.close() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - else: # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - hs.get_db_pool().close() + database._db_pool.close() dropped = False @@ -320,23 +316,12 @@ def setup_test_homeserver( # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: - # If we have been given an explicit datastore we probably want to mock - # out the DataStores somehow too. This all feels a bit wrong, but then - # mocking the stores feels wrong too. - datastores = Mock(datastore=datastore) - hs = homeserverToUse( name, - db_pool=None, datastore=datastore, - datastores=datastores, config=config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, -- cgit 1.4.1 From a831d2e4e3c424fb54f186bfa7d83a17965f933e Mon Sep 17 00:00:00 2001 From: Michael Kaye <1917473+michaelkaye@users.noreply.github.com> Date: Wed, 5 Feb 2020 08:57:38 +0000 Subject: Reduce performance logging to DEBUG (#6833) * Reduce tnx performance logging to DEBUG * Changelog.d --- changelog.d/6833.misc | 1 + synapse/storage/database.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6833.misc (limited to 'synapse/storage/database.py') diff --git a/changelog.d/6833.misc b/changelog.d/6833.misc new file mode 100644 index 0000000000..8a0605f90b --- /dev/null +++ b/changelog.d/6833.misc @@ -0,0 +1 @@ +Reducing log level to DEBUG for synapse.storage.TIME. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 1003dd84a5..3eeb2f7c04 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -343,7 +343,7 @@ class Database(object): top_three_counters = self._txn_perf_counters.interval(duration, limit=3) - perf_logger.info( + perf_logger.debug( "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters ) -- cgit 1.4.1 From 7b7c3cedf2fdc0d0c05bbc651e0ff5b59921c3a2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 19 Feb 2020 15:47:11 +0000 Subject: Minor perf fixes to `get_auth_chain_ids`. --- changelog.d/6954.misc | 1 + synapse/storage/data_stores/main/event_federation.py | 10 ++++------ synapse/storage/database.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6954.misc (limited to 'synapse/storage/database.py') diff --git a/changelog.d/6954.misc b/changelog.d/6954.misc new file mode 100644 index 0000000000..8b84ce2f19 --- /dev/null +++ b/changelog.d/6954.misc @@ -0,0 +1 @@ +Minor perf fixes to `get_auth_chain_ids`. diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index e16da2577d..750ec1b70d 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -16,7 +16,6 @@ import itertools import logging from typing import List, Optional, Set -from six.moves import range from six.moves.queue import Empty, PriorityQueue from twisted.internet import defer @@ -28,6 +27,7 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.database import Database from synapse.util.caches.descriptors import cached +from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) @@ -88,14 +88,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas front = set(event_ids) while front: new_front = set() - front_list = list(front) - chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)] - for chunk in chunks: + for chunk in batch_iter(front, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "event_id", chunk ) - txn.execute(base_sql + clause, list(args)) - new_front.update([r[0] for r in txn]) + txn.execute(base_sql + clause, args) + new_front.update(r[0] for r in txn) new_front -= ignore_events new_front -= results diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 3eeb2f7c04..6dcb5c04da 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1504,7 +1504,7 @@ class Database(object): def make_in_list_sql_clause( database_engine, column: str, iterable: Iterable -) -> Tuple[str, Iterable]: +) -> Tuple[str, list]: """Returns an SQL clause that checks the given column is in the iterable. On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres -- cgit 1.4.1 From 509e381afa8c656e72f5fef3d651a9819794174a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 21 Feb 2020 07:15:07 -0500 Subject: Clarify list/set/dict/tuple comprehensions and enforce via flake8 (#6957) Ensure good comprehension hygiene using flake8-comprehensions. --- CONTRIBUTING.md | 2 +- changelog.d/6957.misc | 1 + docs/code_style.md | 2 +- scripts-dev/convert_server_keys.py | 2 +- synapse/app/_base.py | 2 +- synapse/app/federation_sender.py | 4 +-- synapse/app/pusher.py | 2 +- synapse/config/server.py | 4 +-- synapse/config/tls.py | 2 +- synapse/crypto/keyring.py | 6 ++-- synapse/federation/send_queue.py | 4 +-- synapse/groups/groups_server.py | 2 +- synapse/handlers/device.py | 2 +- synapse/handlers/directory.py | 4 +-- synapse/handlers/federation.py | 18 +++++----- synapse/handlers/presence.py | 6 ++-- synapse/handlers/receipts.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/search.py | 8 ++--- synapse/handlers/sync.py | 22 ++++++------ synapse/handlers/typing.py | 4 +-- synapse/logging/utils.py | 2 +- synapse/metrics/__init__.py | 2 +- synapse/metrics/background_process_metrics.py | 4 +-- synapse/push/bulk_push_rule_evaluator.py | 8 ++--- synapse/push/emailpusher.py | 2 +- synapse/push/mailer.py | 20 +++++------ synapse/push/pusherpool.py | 2 +- synapse/rest/admin/_base.py | 4 +-- synapse/rest/client/v1/push_rule.py | 6 ++-- synapse/rest/client/v1/pusher.py | 4 +-- synapse/rest/client/v2_alpha/sync.py | 2 +- synapse/rest/key/v2/remote_key_resource.py | 2 +- synapse/rest/media/v1/_base.py | 40 ++++++++++------------ synapse/state/v1.py | 10 +++--- synapse/state/v2.py | 8 ++--- synapse/storage/_base.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/data_stores/main/appservice.py | 14 ++++---- synapse/storage/data_stores/main/client_ips.py | 4 +-- synapse/storage/data_stores/main/devices.py | 13 ++++--- .../storage/data_stores/main/event_federation.py | 2 +- synapse/storage/data_stores/main/events.py | 8 ++--- .../storage/data_stores/main/events_bg_updates.py | 2 +- synapse/storage/data_stores/main/events_worker.py | 6 ++-- synapse/storage/data_stores/main/push_rule.py | 8 ++--- synapse/storage/data_stores/main/receipts.py | 4 +-- synapse/storage/data_stores/main/roommember.py | 4 +-- synapse/storage/data_stores/main/state.py | 8 ++--- synapse/storage/data_stores/main/stream.py | 8 ++--- .../storage/data_stores/main/user_erasure_store.py | 4 +-- synapse/storage/data_stores/state/store.py | 4 +-- synapse/storage/database.py | 4 +-- synapse/storage/persist_events.py | 8 ++--- synapse/storage/prepare_database.py | 6 ++-- synapse/util/frozenutils.py | 2 +- synapse/visibility.py | 4 +-- tests/config/test_generate.py | 2 +- tests/federation/test_federation_server.py | 2 +- tests/handlers/test_presence.py | 4 +-- tests/handlers/test_typing.py | 6 ++-- tests/handlers/test_user_directory.py | 12 +++---- tests/push/test_email.py | 6 ++-- tests/push/test_http.py | 8 ++--- tests/rest/client/v2_alpha/test_sync.py | 28 ++++++++------- tests/storage/test__base.py | 4 +-- tests/storage/test_appservice.py | 36 +++++++++---------- tests/storage/test_cleanup_extrems.py | 10 +++--- tests/storage/test_event_metrics.py | 36 +++++++++---------- tests/storage/test_state.py | 2 +- tests/test_state.py | 18 +++------- tests/util/test_stream_change_cache.py | 18 +++------- tox.ini | 1 + 73 files changed, 251 insertions(+), 276 deletions(-) create mode 100644 changelog.d/6957.misc (limited to 'synapse/storage/database.py') diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4b01b6ac8c..253a0ca648 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -60,7 +60,7 @@ python 3.6 and to install each tool: ``` # Install the dependencies -pip install -U black flake8 isort +pip install -U black flake8 flake8-comprehensions isort # Run the linter script ./scripts-dev/lint.sh diff --git a/changelog.d/6957.misc b/changelog.d/6957.misc new file mode 100644 index 0000000000..4f98030110 --- /dev/null +++ b/changelog.d/6957.misc @@ -0,0 +1 @@ +Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions. diff --git a/docs/code_style.md b/docs/code_style.md index 71aecd41f7..6ef6f80290 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -30,7 +30,7 @@ The necessary tools are detailed below. Install `flake8` with: - pip install --upgrade flake8 + pip install --upgrade flake8 flake8-comprehensions Check all application and test code with: diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 179be61c30..06b4c1e2ff 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -103,7 +103,7 @@ def main(): yaml.safe_dump(result, sys.stdout, default_flow_style=False) - rows = list(row for server, json in result.items() for row in rows_v2(server, json)) + rows = [row for server, json in result.items() for row in rows_v2(server, json)] cursor = connection.cursor() cursor.executemany( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 109b1e2fb5..9ffd23c6df 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -141,7 +141,7 @@ def start_reactor( def quit_with_error(error_string): message_lines = error_string.split("\n") - line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2 + line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 63a91f1177..b7fcf80ddc 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -262,7 +262,7 @@ class FederationSenderHandler(object): # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: - hosts = set(row.destination for row in rows) + hosts = {row.destination for row in rows} for host in hosts: self.federation_sender.send_device_messages(host) @@ -270,7 +270,7 @@ class FederationSenderHandler(object): # The to_device stream includes stuff to be pushed to both local # clients and remote servers, so we ignore entities that start with # '@' (since they'll be local users rather than destinations). - hosts = set(row.entity for row in rows if not row.entity.startswith("@")) + hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index e46b6ac598..84e9f8d5e2 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -158,7 +158,7 @@ class PusherReplicationHandler(ReplicationClientHandler): yield self.pusher_pool.on_new_notifications(token, token) elif stream_name == "receipts": yield self.pusher_pool.on_new_receipts( - token, token, set(row.room_id for row in rows) + token, token, {row.room_id for row in rows} ) except Exception: logger.exception("Error poking pushers") diff --git a/synapse/config/server.py b/synapse/config/server.py index 0ec1b0fadd..7525765fee 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1066,12 +1066,12 @@ KNOWN_RESOURCES = ( def _check_resource_config(listeners): - resource_names = set( + resource_names = { res_name for listener in listeners for res in listener.get("resources", []) for res_name in res.get("names", []) - ) + } for resource in resource_names: if resource not in KNOWN_RESOURCES: diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 97a12d51f6..a65538562b 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -260,7 +260,7 @@ class TlsConfig(Config): crypto.FILETYPE_ASN1, self.tls_certificate ) sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest()) - sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints) + sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints} if sha256_fingerprint not in sha256_fingerprints: self.tls_fingerprints.append({"sha256": sha256_fingerprint}) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6fe5a6a26a..983f0ead8c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -326,9 +326,7 @@ class Keyring(object): verify_requests (list[VerifyJsonRequest]): list of verify requests """ - remaining_requests = set( - (rq for rq in verify_requests if not rq.key_ready.called) - ) + remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} @defer.inlineCallbacks def do_iterations(): @@ -396,7 +394,7 @@ class Keyring(object): results = yield fetcher.get_keys(missing_keys) - completed = list() + completed = [] for verify_request in remaining_requests: server_name = verify_request.server_name diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 001bb304ae..876fb0e245 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -129,9 +129,9 @@ class FederationRemoteSendQueue(object): for key in keys[:i]: del self.presence_changed[key] - user_ids = set( + user_ids = { user_id for uids in self.presence_changed.values() for user_id in uids - ) + } keys = self.presence_destinations.keys() i = self.presence_destinations.bisect_left(position_to_delete) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index c106abae21..4f0dc0a209 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -608,7 +608,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): user_results = yield self.store.get_users_in_group( group_id, include_private=True ) - if user_id in [user_result["user_id"] for user_result in user_results]: + if user_id in (user_result["user_id"] for user_result in user_results): raise SynapseError(400, "User already in group") content = { diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 50cea3f378..a514c30714 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -742,6 +742,6 @@ class DeviceListUpdater(object): # We clobber the seen updates since we've re-synced from a given # point. - self._seen_updates[user_id] = set([stream_id]) + self._seen_updates[user_id] = {stream_id} defer.returnValue(result) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 921d887b24..0b23ca919a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -72,7 +72,7 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Check if there is a current association. if not servers: users = yield self.state.get_current_users_in_room(room_id) - servers = set(get_domain_from_id(u) for u in users) + servers = {get_domain_from_id(u) for u in users} if not servers: raise SynapseError(400, "Failed to get server list") @@ -255,7 +255,7 @@ class DirectoryHandler(BaseHandler): ) users = yield self.state.get_current_users_in_room(room_id) - extra_servers = set(get_domain_from_id(u) for u in users) + extra_servers = {get_domain_from_id(u) for u in users} servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb20ef4aec..a689065f89 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -659,11 +659,11 @@ class FederationHandler(BaseHandler): # this can happen if a remote server claims that the state or # auth_events at an event in room A are actually events in room B - bad_events = list( + bad_events = [ (event_id, event.room_id) for event_id, event in fetched_events.items() if event.room_id != room_id - ) + ] for bad_event_id, bad_room_id in bad_events: # This is a bogus situation, but since we may only discover it a long time @@ -856,7 +856,7 @@ class FederationHandler(BaseHandler): # Don't bother processing events we already have. seen_events = await self.store.have_events_in_timeline( - set(e.event_id for e in events) + {e.event_id for e in events} ) events = [e for e in events if e.event_id not in seen_events] @@ -866,7 +866,7 @@ class FederationHandler(BaseHandler): event_map = {e.event_id: e for e in events} - event_ids = set(e.event_id for e in events) + event_ids = {e.event_id for e in events} # build a list of events whose prev_events weren't in the batch. # (XXX: this will include events whose prev_events we already have; that doesn't @@ -892,13 +892,13 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - required_auth = set( + required_auth = { a_id for event in events + list(state_events.values()) + list(auth_events.values()) for a_id in event.auth_event_ids() - ) + } auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) @@ -1247,7 +1247,7 @@ class FederationHandler(BaseHandler): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -2152,7 +2152,7 @@ class FederationHandler(BaseHandler): # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - [auth_id for auth_id in event.auth_event_ids()], include_given=True + list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell @@ -2654,7 +2654,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() yield member_handler.send_membership_event(None, event, context) else: - destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) + destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} yield self.federation_client.forward_third_party_invite( destinations, room_id, event_dict ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 202aa9294f..0d6cf2b008 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -313,7 +313,7 @@ class PresenceHandler(object): notified_presence_counter.inc(len(to_notify)) yield self._persist_and_notify(list(to_notify.values())) - self.unpersisted_users_changes |= set(s.user_id for s in new_states) + self.unpersisted_users_changes |= {s.user_id for s in new_states} self.unpersisted_users_changes -= set(to_notify.keys()) to_federation_ping = { @@ -698,7 +698,7 @@ class PresenceHandler(object): updates = yield self.current_state_for_users(target_user_ids) updates = list(updates.values()) - for user_id in set(target_user_ids) - set(u.user_id for u in updates): + for user_id in set(target_user_ids) - {u.user_id for u in updates}: updates.append(UserPresenceState.default(user_id)) now = self.clock.time_msec() @@ -886,7 +886,7 @@ class PresenceHandler(object): hosts = yield self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = set(host for host in hosts if host != self.server_name) + hosts = {host for host in hosts if host != self.server_name} self.federation.send_presence_to_destinations( states=[state], destinations=hosts diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 9283c039e3..8bc100db42 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -94,7 +94,7 @@ class ReceiptsHandler(BaseHandler): # no new receipts return False - affected_room_ids = list(set([r.room_id for r in receipts])) + affected_room_ids = list({r.room_id for r in receipts}) self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) # Note that the min here shouldn't be relied upon to be accurate. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 76e8f61b74..8ee870f0bb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -355,7 +355,7 @@ class RoomCreationHandler(BaseHandler): # If so, mark the new room as non-federatable as well creation_content["m.federate"] = False - initial_state = dict() + initial_state = {} # Replicate relevant room events types_to_copy = ( diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 110097eab9..ec1542d416 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -184,7 +184,7 @@ class SearchHandler(BaseHandler): membership_list=[Membership.JOIN], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], ) - room_ids = set(r.room_id for r in rooms) + room_ids = {r.room_id for r in rooms} # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well @@ -374,12 +374,12 @@ class SearchHandler(BaseHandler): ).to_string() if include_profile: - senders = set( + senders = { ev.sender for ev in itertools.chain( res["events_before"], [event], res["events_after"] ) - ) + } if res["events_after"]: last_event_id = res["events_after"][-1].event_id @@ -421,7 +421,7 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = set(e.room_id for e in allowed_events) + rooms = {e.room_id for e in allowed_events} for room_id in rooms: state = yield self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4324bc702e..669dbc8a48 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -682,11 +682,9 @@ class SyncHandler(object): # FIXME: order by stream ordering rather than as returned by SQL if joined_user_ids or invited_user_ids: - summary["m.heroes"] = sorted( - [user_id for user_id in (joined_user_ids + invited_user_ids)] - )[0:5] + summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5] else: - summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5] + summary["m.heroes"] = sorted(gone_user_ids)[0:5] if not sync_config.filter_collection.lazy_load_members(): return summary @@ -697,9 +695,9 @@ class SyncHandler(object): # track which members the client should already know about via LL: # Ones which are already in state... - existing_members = set( + existing_members = { user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member - ) + } # ...or ones which are in the timeline... for ev in batch.events: @@ -773,10 +771,10 @@ class SyncHandler(object): # We only request state for the members needed to display the # timeline: - members_to_fetch = set( + members_to_fetch = { event.sender # FIXME: we also care about invite targets etc. for event in batch.events - ) + } if full_state: # always make sure we LL ourselves so we know we're in the room @@ -1993,10 +1991,10 @@ def _calculate_state( ) } - c_ids = set(e for e in itervalues(current)) - ts_ids = set(e for e in itervalues(timeline_start)) - p_ids = set(e for e in itervalues(previous)) - tc_ids = set(e for e in itervalues(timeline_contains)) + c_ids = set(itervalues(current)) + ts_ids = set(itervalues(timeline_start)) + p_ids = set(itervalues(previous)) + tc_ids = set(itervalues(timeline_contains)) # If we are lazyloading room members, we explicitly add the membership events # for the senders in the timeline into the state block returned by /sync, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 5406618431..391bceb0c4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -198,7 +198,7 @@ class TypingHandler(object): now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - for domain in set(get_domain_from_id(u) for u in users): + for domain in {get_domain_from_id(u) for u in users}: if domain != self.server_name: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( @@ -231,7 +231,7 @@ class TypingHandler(object): return users = yield self.state.get_current_users_in_room(room_id) - domains = set(get_domain_from_id(u) for u in users) + domains = {get_domain_from_id(u) for u in users} if self.server_name in domains: logger.info("Got typing update from %s: %r", user_id, content) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 6073fc2725..0c2527bd86 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -148,7 +148,7 @@ def trace_function(f): pathname=pathname, lineno=lineno, msg=msg, - args=tuple(), + args=(), exc_info=None, ) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 0b45e1f52a..0dba997a23 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -240,7 +240,7 @@ class BucketCollector(object): res.append(["+Inf", sum(data.values())]) metric = HistogramMetricFamily( - self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()]) + self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items()) ) yield metric diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index c53d2a0d40..b65bcd8806 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -80,13 +80,13 @@ _background_process_db_sched_duration = Counter( # map from description to a counter, so that we can name our logcontexts # incrementally. (It actually duplicates _background_process_start_count, but # it's much simpler to do so than to try to combine them.) -_background_process_counts = dict() # type: dict[str, int] +_background_process_counts = {} # type: dict[str, int] # map from description to the currently running background processes. # # it's kept as a dict of sets rather than a big set so that we can keep track # of process descriptions that no longer have any active processes. -_background_processes = dict() # type: dict[str, set[_BackgroundProcess]] +_background_processes = {} # type: dict[str, set[_BackgroundProcess]] # A lock that covers the above dicts _bg_metrics_lock = threading.Lock() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7d9f5a38d9..433ca2f416 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -400,11 +400,11 @@ class RulesForRoom(object): if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) - interested_in_user_ids = set( + interested_in_user_ids = { user_id for user_id, membership in itervalues(members) if membership == Membership.JOIN - ) + } logger.debug("Joined: %r", interested_in_user_ids) @@ -412,9 +412,9 @@ class RulesForRoom(object): interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) - user_ids = set( + user_ids = { uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher - ) + } logger.debug("With pushers: %r", user_ids) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 8c818a86bf..ba4551d619 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -204,7 +204,7 @@ class EmailPusher(object): yield self.send_notification(unprocessed, reason) yield self.save_last_stream_ordering_and_success( - max([ea["stream_ordering"] for ea in unprocessed]) + max(ea["stream_ordering"] for ea in unprocessed) ) # we update the throttle on all the possible unprocessed push actions diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index b13b646bfd..4ccaf178ce 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -526,12 +526,10 @@ class Mailer(object): # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[room_id] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[room_id] + } ) member_events = yield self.store.get_events( @@ -558,12 +556,10 @@ class Mailer(object): # If the reason room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" sender_ids = list( - set( - [ - notif_events[n["event_id"]].sender - for n in notifs_by_room[reason["room_id"]] - ] - ) + { + notif_events[n["event_id"]].sender + for n in notifs_by_room[reason["room_id"]] + } ) member_events = yield self.store.get_events( diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index b9dca5bc63..01789a9fb4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -191,7 +191,7 @@ class PusherPool: min_stream_id - 1, max_stream_id ) # This returns a tuple, user_id is at index 3 - users_affected = set([r[3] for r in updated_receipts]) + users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py index 459482eb6d..a96f75ce26 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py @@ -29,7 +29,7 @@ def historical_admin_path_patterns(path_regex): Note that this should only be used for existing endpoints: new ones should just register for the /_synapse/admin path. """ - return list( + return [ re.compile(prefix + path_regex) for prefix in ( "^/_synapse/admin/v1", @@ -37,7 +37,7 @@ def historical_admin_path_patterns(path_regex): "^/_matrix/client/unstable/admin", "^/_matrix/client/r0/admin", ) - ) + ] def admin_patterns(path_regex: str): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 4f74600239..9fd4908136 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -49,7 +49,7 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) try: priority_class = _priority_class_from_spec(spec) except InvalidRuleException as e: @@ -110,7 +110,7 @@ class PushRuleRestServlet(RestServlet): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") - spec = _rule_spec_from_path([x for x in path.split("/")]) + spec = _rule_spec_from_path(path.split("/")) requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,7 +138,7 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = [x for x in path.split("/")][1:] + path = path.split("/")[1:] if path == []: # we're a reference impl: pedantry is our job. diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 6f6b7aed6e..550a2f1b44 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -54,9 +54,9 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - filtered_pushers = list( + filtered_pushers = [ {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers - ) + ] return 200, {"pushers": filtered_pushers} diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index d8292ce29f..8fa68dd37f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -72,7 +72,7 @@ class SyncRestServlet(RestServlet): """ PATTERNS = client_patterns("/sync$") - ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): super(SyncRestServlet, self).__init__() diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 9d6813a047..4b6d030a57 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -149,7 +149,7 @@ class RemoteKey(DirectServeResource): time_now_ms = self.clock.time_msec() - cache_misses = dict() # type: Dict[str, Set[str]] + cache_misses = {} # type: Dict[str, Set[str]] for (server_name, key_id, from_server), results in cached.items(): results = [(result["ts_added_ms"], result) for result in results] diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 65bbf00073..ba28dd089d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -135,27 +135,25 @@ def add_file_headers(request, media_type, file_size, upload_name): # separators as defined in RFC2616. SP and HT are handled separately. # see _can_encode_filename_as_token. -_FILENAME_SEPARATOR_CHARS = set( - ( - "(", - ")", - "<", - ">", - "@", - ",", - ";", - ":", - "\\", - '"', - "/", - "[", - "]", - "?", - "=", - "{", - "}", - ) -) +_FILENAME_SEPARATOR_CHARS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", +} def _can_encode_filename_as_token(x): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 24b7c0faef..9bf98d06f2 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -69,9 +69,9 @@ def resolve_events_with_store( unconflicted_state, conflicted_state = _seperate(state_sets) - needed_events = set( + needed_events = { event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids - ) + } needed_event_count = len(needed_events) if event_map is not None: needed_events -= set(iterkeys(event_map)) @@ -261,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events): def _resolve_auth_events(events, auth_events): - reverse = [i for i in reversed(_ordered_events(events))] + reverse = list(reversed(_ordered_events(events))) - auth_keys = set( + auth_keys = { key for event in events for key in event_auth.auth_types_for_event(event) - ) + } new_auth_events = {} for key in auth_keys: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 75fe58305a..0ffe6d8c14 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -105,7 +105,7 @@ def resolve_events_with_store( % (room_id, event.event_id, event.room_id,) ) - full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) + full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -233,7 +233,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): auth_sets = [] for state_set in state_sets: - auth_ids = set( + auth_ids = { eid for key, eid in iteritems(state_set) if ( @@ -246,7 +246,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): ) ) and eid not in common - ) + } auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) auth_ids.update(auth_chain) @@ -275,7 +275,7 @@ def _seperate(state_sets): conflicted_state = {} for key in set(itertools.chain.from_iterable(state_sets)): - event_ids = set(state_set.get(key) for state_set in state_sets) + event_ids = {state_set.get(key) for state_set in state_sets} if len(event_ids) == 1: unconflicted_state[key] = event_ids.pop() else: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index da3b99f93d..13de5f1f62 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta): members_changed (iterable[str]): The user_ids of members that have changed """ - for host in set(get_domain_from_id(u) for u in members_changed): + for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) self._attempt_to_invalidate_cache("was_host_joined", (room_id, host)) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index bd547f35cf..eb1a7e5002 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -189,7 +189,7 @@ class BackgroundUpdater(object): keyvalues=None, retcols=("update_name", "depends_on"), ) - in_flight = set(update["update_name"] for update in updates) + in_flight = {update["update_name"] for update in updates} for update in updates: if update["depends_on"] not in in_flight: self._background_update_queue.append(update["update_name"]) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index b2f39649fd..efbc06c796 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore( may be empty. """ results = yield self.db.simple_select_list( - "application_services_state", dict(state=state), ["as_id"] + "application_services_state", {"state": state}, ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore as_list = self.get_app_services() @@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore( """ result = yield self.db.simple_select_one( "application_services_state", - dict(as_id=service.id), + {"as_id": service.id}, ["state"], allow_none=True, desc="get_appservice_state", @@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves when the state was set successfully. """ return self.db.simple_upsert( - "application_services_state", dict(as_id=service.id), dict(state=state) + "application_services_state", {"as_id": service.id}, {"state": state} ) def create_appservice_txn(self, service, events): @@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore( self.db.simple_upsert_txn( txn, "application_services_state", - dict(as_id=service.id), - dict(last_txn=txn_id), + {"as_id": service.id}, + {"last_txn": txn_id}, ) # Delete txn self.db.simple_delete_txn( - txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) + txn, + "application_services_txns", + {"txn_id": txn_id, "as_id": service.id}, ) return self.db.runInteraction( diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 13f4c9c72e..e1ccb27142 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) for row in rows ) - return list( + return [ { "access_token": access_token, "ip": ip, @@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): "last_seen": last_seen, } for (access_token, ip), (user_agent, last_seen) in iteritems(results) - ) + ] @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index b7617efb80..d55733a4cd 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore): # get the cross-signing keys of the users in the list, so that we can # determine which of the device changes were cross-signing keys - users = set(r[0] for r in updates) + users = {r[0] for r in updates} master_key_by_user = {} self_signing_key_by_user = {} for user in users: @@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - user_ids = set(user_id for user_id, _ in query_list) + user_ids = {user_id for user_id, _ in query_list} user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) # We go and check if any of the users need to have their device lists @@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore): users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( user_ids ) - user_ids_in_cache = ( - set(user_id for user_id, stream_id in user_map.items() if stream_id) - - users_needing_resync - ) + user_ids_in_cache = { + user_id for user_id, stream_id in user_map.items() if stream_id + } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - return set(user for row in rows for user in json.loads(row[0])) + return {user for row in rows for user in json.loads(row[0])} else: return set() diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 750ec1b70d..49a7b8b433 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -426,7 +426,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas query, (room_id, event_id, False, limit - len(event_results)) ) - new_results = set(t[0] for t in txn) - seen_events + new_results = {t[0] for t in txn} - seen_events new_front |= new_results seen_events |= new_results diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index c9d0d68c3a..8ae23df00a 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -145,7 +145,7 @@ class EventsStore( return txn.fetchall() res = yield self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) + self._current_forward_extremities_amount = c_counter([x[0] for x in res]) @_retry_on_integrity_error @defer.inlineCallbacks @@ -598,11 +598,11 @@ class EventsStore( # We find out which membership events we may have deleted # and which we have added, then we invlidate the caches for all # those users. - members_changed = set( + members_changed = { state_key for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member - ) + } for member in members_changed: txn.call_after( @@ -1615,7 +1615,7 @@ class EventsStore( """ ) - referenced_state_groups = set(sg for sg, in txn) + referenced_state_groups = {sg for sg, in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index 5177b71016..f54c8b1ee0 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): keyvalues={}, retcols=("room_id",), ) - room_ids = set(row["room_id"] for row in rows) + room_ids = {row["room_id"] for row in rows} for room_id in room_ids: txn.call_after( self.get_latest_event_ids_in_room.invalidate, (room_id,) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 7251e819f5..47a3a26072 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -494,9 +494,9 @@ class EventsWorkerStore(SQLBaseStore): """ with Measure(self._clock, "_fetch_event_list"): try: - events_to_fetch = set( + events_to_fetch = { event_id for events, _ in event_list for event_id in events - ) + } row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch @@ -804,7 +804,7 @@ class EventsWorkerStore(SQLBaseStore): desc="have_events_in_timeline", ) - return set(r["event_id"] for r in rows) + return {r["event_id"] for r in rows} @defer.inlineCallbacks def have_seen_events(self, event_ids): diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index e2673ae073..62ac88d9f2 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -276,21 +276,21 @@ class PushRulesWorkerStore( # We ignore app service users for now. This is so that we don't fill # up the `get_if_users_have_pushers` cache with AS entries that we # know don't have pushers, nor even read receipts. - local_users_in_room = set( + local_users_in_room = { u for u in users_in_room if self.hs.is_mine_id(u) and not self.get_if_app_services_interested_in_user(u) - ) + } # users in the room who have pushers need to get push rules run because # that's how their pushers work if_users_with_pushers = yield self.get_if_users_have_pushers( local_users_in_room, on_invalidate=cache_context.invalidate ) - user_ids = set( + user_ids = { uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher - ) + } users_with_receipts = yield self.get_users_with_read_receipts_in_room( room_id, on_invalidate=cache_context.invalidate diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 96e54d145e..0d932a0672 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") - return set(r["user_id"] for r in receipts) + return {r["user_id"] for r in receipts} @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): @@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return list(r[0:5] + (json.loads(r[5]),) for r in txn) + return [r[0:5] + (json.loads(r[5]),) for r in txn] return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5ced05701..d5bd0cb5cf 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql % (clause,), args) - return set(row[0] for row in txn) + return {row[0] for row in txn} return await self.db.runInteraction( "get_users_server_still_shares_room_with", @@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): GROUP BY room_id, user_id; """ txn.execute(sql, (user_id,)) - return set(row[0] for row in txn if row[1] == 0) + return {row[0] for row in txn if row[1] == 0} return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3d34103e67..3a3b9a8e72 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="get_referenced_state_groups", ) - return set(row["state_group"] for row in rows) + return {row["state_group"] for row in rows} class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): @@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): """ txn.execute(sql, (last_room_id, batch_size)) - room_ids = list(row[0] for row in txn) + room_ids = [row[0] for row in txn] if not room_ids: return True, set() @@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name)) - joined_room_ids = set(row[0] for row in txn) + joined_room_ids = {row[0] for row in txn} left_rooms = set(room_ids) - joined_room_ids @@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): retcols=("state_key",), ) - potentially_left_users = set(row["state_key"] for row in rows) + potentially_left_users = {row["state_key"] for row in rows} # Now lets actually delete the rooms from the DB. self.db.simple_delete_many_txn( diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 056b25b13a..ada5cce6c2 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_key (str): The room_key portion of a StreamToken """ from_key = RoomStreamToken.parse_stream_token(from_key).stream - return set( + return { room_id for room_id in room_ids if self._events_stream_cache.has_entity_changed(room_id, from_key) - ) + } @defer.inlineCallbacks def get_room_events_stream_for_room( @@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = yield self.get_events_as_list( - [e for e in results["before"]["event_ids"]], get_prev_content=True + list(results["before"]["event_ids"]), get_prev_content=True ) events_after = yield self.get_events_as_list( - [e for e in results["after"]["event_ids"]], get_prev_content=True + list(results["after"]["event_ids"]), get_prev_content=True ) return { diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index af8025bc17..ec6b8a4ffd 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore): retcols=("user_id",), desc="are_users_erased", ) - erased_users = set(row["user_id"] for row in rows) + erased_users = {row["user_id"] for row in rows} - res = dict((u, u in erased_users) for u in user_ids) + res = {u: u in erased_users for u in user_ids} return res diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index c4ee9b7ccb..57a5267663 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): retcols=("state_group",), ) - remaining_state_groups = set( + remaining_state_groups = { row["state_group"] for row in rows if row["state_group"] not in state_groups_to_delete - ) + } logger.info( "[purge] de-delta-ing %i remaining state groups", diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6dcb5c04da..1953614401 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -554,8 +554,8 @@ class Database(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) + col_headers = [intern(str(column[0])) for column in cursor.description] + results = [dict(zip(col_headers, row)) for row in cursor] return results def execute(self, desc, decoder, query, *args): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index b950550f23..0f9ac1cf09 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -602,14 +602,14 @@ class EventsPersistenceStorage(object): event_id_to_state_group.update(event_to_groups) # State groups of old_latest_event_ids - old_state_groups = set( + old_state_groups = { event_id_to_state_group[evid] for evid in old_latest_event_ids - ) + } # State groups of new_latest_event_ids - new_state_groups = set( + new_state_groups = { event_id_to_state_group[evid] for evid in new_latest_event_ids - ) + } # If they old and new groups are the same then we don't need to do # anything. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index c285ef52a0..fc69c32a0a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -345,9 +345,9 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) - duplicates = set( + duplicates = { file_name for file_name, count in file_name_counter.items() if count > 1 - ) + } if duplicates: # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( @@ -454,7 +454,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ), (modname,), ) - applied_deltas = set(d for d, in cur) + applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: if name in applied_deltas: continue diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py index 635b897d6c..f2ccd5e7c6 100644 --- a/synapse/util/frozenutils.py +++ b/synapse/util/frozenutils.py @@ -30,7 +30,7 @@ def freeze(o): return o try: - return tuple([freeze(i) for i in o]) + return tuple(freeze(i) for i in o) except TypeError: pass diff --git a/synapse/visibility.py b/synapse/visibility.py index d0abd8f04f..e60d9756b7 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -75,7 +75,7 @@ def filter_events_for_client( """ # Filter out events that have been soft failed so that we don't relay them # to clients. - events = list(e for e in events if not e.internal_metadata.is_soft_failed()) + events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) event_id_to_state = yield storage.state.get_state_for_events( @@ -97,7 +97,7 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) if apply_retention_policies: - room_ids = set(e.room_id for e in events) + room_ids = {e.room_id for e in events} retention_policies = {} for room_id in room_ids: diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 2684e662de..463855ecc8 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase): ) self.assertSetEqual( - set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), + {"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"}, set(os.listdir(self.dir)), ) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index e7d8699040..296dc887be 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -83,7 +83,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): ) ) - self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(members, {"@user:other.example.com", u1}) self.assertEqual(len(channel.json_body["pdus"]), 6) def test_needs_to_be_in_room(self): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index c171038df8..64915bafcd 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -338,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids=set([user_id]), now=now + state, is_mine=True, syncing_user_ids={user_id}, now=now ) self.assertIsNotNone(new_state) @@ -579,7 +579,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), states=[expected_state] + destinations={"server2", "server3"}, states=[expected_state] ) def _add_new_user(self, room_id, user_id): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 140cc0a3c2..07b204666e 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,12 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_auth().check_user_in_room = check_user_in_room def get_joined_hosts_for_room(room_id): - return set(member.domain for member in self.room_members) + return {member.domain for member in self.room_members} self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room def get_current_users_in_room(room_id): - return set(str(u) for u in self.room_members) + return {str(u) for u in self.room_members} hs.get_state_handler().get_current_users_in_room = get_current_users_in_room @@ -257,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) + self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()} self.assertEquals(self.event_source.get_current_key(), 0) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0a4765fff4..7b92bdbc47 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -226,7 +226,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() self.assertEqual( - self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)]) + self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)} ) self.assertEqual(public_users, []) @@ -358,12 +358,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): public_users = self.get_users_in_public_rooms() # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), set([(u1, room), (u2, room)])) + self.assertEqual(set(public_users), {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms self.assertEqual( self._compress_shared(shares_private), - set([(u1, u3, private_room), (u3, u1, private_room)]), + {(u1, u3, private_room), (u3, u1, private_room)}, ) def test_initial_share_all_users(self): @@ -398,7 +398,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # No users share rooms self.assertEqual(public_users, []) - self.assertEqual(self._compress_shared(shares_private), set([])) + self.assertEqual(self._compress_shared(shares_private), set()) # Despite not sharing a room, search_all_users means we get a search # result. diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 80187406bc..83032cc9ea 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -163,7 +163,7 @@ class EmailPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -174,7 +174,7 @@ class EmailPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -192,7 +192,7 @@ class EmailPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index fe3441f081..baf9c785f4 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -102,7 +102,7 @@ class HTTPPusherTests(HomeserverTestCase): # Get the stream ordering before it gets sent pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -113,7 +113,7 @@ class HTTPPusherTests(HomeserverTestCase): # It hasn't succeeded yet, so the stream ordering shouldn't have moved pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -132,7 +132,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) @@ -152,7 +152,7 @@ class HTTPPusherTests(HomeserverTestCase): # The stream ordering has increased, again pushers = self.get_success( - self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) + self.hs.get_datastore().get_pushers_by({"user_name": user_id}) ) pushers = list(pushers) self.assertEqual(len(pushers), 1) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 9c13a13786..fa3a3ec1bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -40,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - [ - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - ] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "presence", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) def test_sync_presence_disabled(self): @@ -63,9 +61,13 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertTrue( - set( - ["next_batch", "rooms", "account_data", "to_device", "device_lists"] - ).issubset(set(channel.json_body.keys())) + { + "next_batch", + "rooms", + "account_data", + "to_device", + "device_lists", + }.issubset(set(channel.json_body.keys())) ) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index d491ea2924..e37260a820 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -373,7 +373,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "there")]), + {(1, "user1", "hello"), (2, "user2", "there")}, ) # Update only user2 @@ -400,5 +400,5 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) self.assertEqual( set(self._dump_to_tuple(res)), - set([(1, "user1", "hello"), (2, "user2", "bleb")]), + {(1, "user1", "hello"), (2, "user2", "bleb")}, ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index fd52512696..31710949a8 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -69,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): pass def _add_appservice(self, as_token, id, url, hs_token, sender): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token=hs_token, - id=id, - sender_localpart=sender, - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": hs_token, + "id": id, + "sender_localpart": sender, + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -135,14 +135,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) def _add_service(self, url, as_token, id): - as_yaml = dict( - url=url, - as_token=as_token, - hs_token="something", - id=id, - sender_localpart="a_sender", - namespaces={}, - ) + as_yaml = { + "url": url, + "as_token": as_token, + "hs_token": "something", + "id": id, + "sender_localpart": "a_sender", + "namespaces": {}, + } # use the token as the filename with open(as_token, "w") as outfile: outfile.write(yaml.dump(as_yaml)) @@ -384,8 +384,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) self.assertEquals(2, len(services)) self.assertEquals( - set([self.as_list[2]["id"], self.as_list[0]["id"]]), - set([services[0].id, services[1].id]), + {self.as_list[2]["id"], self.as_list[0]["id"]}, + {services[0].id, services[1].id}, ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 029ac26454..0e04b2cf92 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -134,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -172,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b))) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -227,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual( - set(latest_event_ids), set((event_id_a, event_id_b, event_id_c)) - ) + self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -237,7 +235,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c])) + self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index f26ff57a18..a7b7fd36d3 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) ) - expected = set( - [ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', - b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', - b"synapse_forward_extremities_count 3.0", - b"synapse_forward_extremities_sum 10.0", - ] - ) + expected = { + b'synapse_forward_extremities_bucket{le="1.0"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0"} 3.0', + b'synapse_forward_extremities_bucket{le="50.0"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="+Inf"} 3.0', + b"synapse_forward_extremities_count 3.0", + b"synapse_forward_extremities_sum 10.0", + } self.assertEqual(items, expected) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 04d58fbf24..0b88308ff4 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -394,7 +394,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) = self.state_datastore._state_group_cache.get(group) self.assertEqual(is_all, False) - self.assertEqual(known_absent, set([(e1.type, e1.state_key)])) + self.assertEqual(known_absent, {(e1.type, e1.state_key)}) self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id}) ############################################ diff --git a/tests/test_state.py b/tests/test_state.py index d1578fe581..66f22f6813 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -254,9 +254,7 @@ class StateTestCase(unittest.TestCase): ctx_d = context_store["D"] prev_state_ids = yield ctx_d.get_prev_state_ids() - self.assertSetEqual( - {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} - ) + self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @@ -313,9 +311,7 @@ class StateTestCase(unittest.TestCase): ctx_e = context_store["E"] prev_state_ids = yield ctx_e.get_prev_state_ids() - self.assertSetEqual( - {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} - ) + self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -388,9 +384,7 @@ class StateTestCase(unittest.TestCase): ctx_d = context_store["D"] prev_state_ids = yield ctx_d.get_prev_state_ids() - self.assertSetEqual( - {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} - ) + self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group) @@ -482,7 +476,7 @@ class StateTestCase(unittest.TestCase): current_state_ids = yield context.get_current_state_ids() self.assertEqual( - set([e.event_id for e in old_state]), set(current_state_ids.values()) + {e.event_id for e in old_state}, set(current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -513,9 +507,7 @@ class StateTestCase(unittest.TestCase): prev_state_ids = yield context.get_prev_state_ids() - self.assertEqual( - set([e.event_id for e in old_state]), set(prev_state_ids.values()) - ) + self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) self.assertIsNotNone(context.state_group) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index f2be63706b..72a9de5370 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -67,7 +67,7 @@ class StreamChangeCacheTests(unittest.TestCase): # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) self.assertEqual( - set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) def test_get_all_entities_changed(self): @@ -137,7 +137,7 @@ class StreamChangeCacheTests(unittest.TestCase): cache.get_entities_changed( ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries mid-way through the stream, but include one @@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=2, ), - set(["bar@baz.net", "user@elsewhere.org"]), + {"bar@baz.net", "user@elsewhere.org"}, ) # Query all the entries, but before the first known point. We will get @@ -168,21 +168,13 @@ class StreamChangeCacheTests(unittest.TestCase): ], stream_pos=0, ), - set( - [ - "user@foo.com", - "bar@baz.net", - "user@elsewhere.org", - "not@here.website", - ] - ), + {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, ) # Query a subset of the entries mid-way through the stream. We should # only get back the subset. self.assertEqual( - cache.get_entities_changed(["bar@baz.net"], stream_pos=2), - set(["bar@baz.net"]), + cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"}, ) def test_max_pos(self): diff --git a/tox.ini b/tox.ini index b9132a3177..b715ea0bff 100644 --- a/tox.ini +++ b/tox.ini @@ -123,6 +123,7 @@ skip_install = True basepython = python3.6 deps = flake8 + flake8-comprehensions black==19.10b0 # We pin so that our tests don't start failing on new releases of black. commands = python -m black --check --diff . -- cgit 1.4.1 From 132b673dbefa42eb7669a11522426f26e225ac05 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 Feb 2020 11:53:40 +0000 Subject: Add some type annotations in `synapse.storage` (#6987) I cracked, and added some type definitions in synapse.storage. --- changelog.d/6987.misc | 1 + synapse/storage/database.py | 143 +++++++++++++++++++++--------------- synapse/storage/engines/__init__.py | 28 +++---- synapse/storage/engines/_base.py | 87 ++++++++++++++++++++++ synapse/storage/engines/postgres.py | 12 +-- synapse/storage/engines/sqlite.py | 13 ++-- synapse/storage/types.py | 65 ++++++++++++++++ tox.ini | 5 +- 8 files changed, 270 insertions(+), 84 deletions(-) create mode 100644 changelog.d/6987.misc create mode 100644 synapse/storage/types.py (limited to 'synapse/storage/database.py') diff --git a/changelog.d/6987.misc b/changelog.d/6987.misc new file mode 100644 index 0000000000..7ff74cda55 --- /dev/null +++ b/changelog.d/6987.misc @@ -0,0 +1 @@ +Add some type annotations to the database storage classes. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 1953614401..609db40616 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -15,9 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import sys import time -from typing import Iterable, Tuple +from time import monotonic as monotonic_time +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.types import Connection, Cursor from synapse.util.stringutils import exception_to_unicode -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time - logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 +# python 3 does not have a maximum int value +MAX_TXN_ID = 2 ** 63 - 1 sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") @@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine + reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine ) -> adbapi.ConnectionPool: """Get the connection pool for the database. """ @@ -90,7 +80,9 @@ def make_pool( ) -def make_conn(db_config: DatabaseConnectionConfig, engine): +def make_conn( + db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine +) -> Connection: """Make a new connection to the database and return it. Returns: @@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine): return db_conn -class LoggingTransaction(object): +# The type of entry which goes on our after_callbacks and exception_callbacks lists. +# +# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so +# that mypy sees the type but the runtime python doesn't. +_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]] + + +class LoggingTransaction: """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method. Args: txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to + name: The name of this transactions for logging. + database_engine + after_callbacks: A list that callbacks will be appended to that have been added by `call_after` which should be run on successful completion of the transaction. None indicates that no callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended + exception_callbacks: A list that callbacks will be appended to that have been added by `call_on_exception` which should be run if transaction ends with an error. None indicates that no callbacks should be allowed to be scheduled to run. @@ -135,46 +134,67 @@ class LoggingTransaction(object): ] def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + self, + txn: Cursor, + name: str, + database_engine: BaseDatabaseEngine, + after_callbacks: Optional[List[_CallbackListEntry]] = None, + exception_callbacks: Optional[List[_CallbackListEntry]] = None, ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) + self.txn = txn + self.name = name + self.database_engine = database_engine + self.after_callbacks = after_callbacks + self.exception_callbacks = exception_callbacks - def call_after(self, callback, *args, **kwargs): + def call_after(self, callback: "Callable[..., None]", *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ + # if self.after_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.after_callbacks is not None self.after_callbacks.append((callback, args, kwargs)) - def call_on_exception(self, callback, *args, **kwargs): + def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs): + # if self.exception_callbacks is None, that means that whatever constructed the + # LoggingTransaction isn't expecting there to be any callbacks; assert that + # is not the case. + assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) - def __getattr__(self, name): - return getattr(self.txn, name) + def fetchall(self) -> List[Tuple]: + return self.txn.fetchall() - def __setattr__(self, name, value): - setattr(self.txn, name, value) + def fetchone(self) -> Tuple: + return self.txn.fetchone() - def __iter__(self): + def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() + @property + def rowcount(self) -> int: + return self.txn.rowcount + + @property + def description(self) -> Any: + return self.txn.description + def execute_batch(self, sql, args): if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch + from psycopg2.extras import execute_batch # type: ignore self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: for val in args: self.execute(sql, val) - def execute(self, sql, *args): + def execute(self, sql: str, *args: Any): self._do_execute(self.txn.execute, sql, *args) - def executemany(self, sql, *args): + def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) def _make_sql_one_line(self, sql): @@ -207,6 +227,9 @@ class LoggingTransaction(object): sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) sql_query_timer.labels(sql.split()[0]).observe(secs) + def close(self): + self.txn.close() + class PerformanceCounters(object): def __init__(self): @@ -251,7 +274,9 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): + def __init__( + self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + ): self.hs = hs self._clock = hs.get_clock() self._database_config = database_config @@ -259,9 +284,9 @@ class Database(object): self.updates = BackgroundUpdater(hs, self) - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 + self._previous_txn_total_time = 0.0 + self._current_txn_total_time = 0.0 + self._previous_loop_ts = 0.0 # TODO(paul): These can eventually be removed once the metrics code # is running in mainline, and we have some nice monitoring frontends @@ -463,23 +488,23 @@ class Database(object): sql_txn_timer.labels(desc).observe(duration) @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): + def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any): """Starts a transaction on the database and runs a given function Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a + desc: description of the transaction, for logging and metrics + func: callback function, which will be called with a database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func """ - after_callbacks = [] - exception_callbacks = [] + after_callbacks = [] # type: List[_CallbackListEntry] + exception_callbacks = [] # type: List[_CallbackListEntry] if LoggingContext.current_context() == LoggingContext.sentinel: logger.warning("Starting db txn '%s' from sentinel context", desc) @@ -505,15 +530,15 @@ class Database(object): return result @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): + def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any): """Wraps the .runWithConnection() method on the underlying db_pool. Arguments: - func (func): callback function, which will be called with a + func: callback function, which will be called with a database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` + args: positional args to pass to `func` + kwargs: named args to pass to `func` Returns: Deferred: The result of func @@ -800,7 +825,7 @@ class Database(object): return False # We didn't find any existing rows, so insert a new one - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(values) allvalues.update(insertion_values) @@ -829,7 +854,7 @@ class Database(object): Returns: None """ - allvalues = {} + allvalues = {} # type: Dict[str, Any] allvalues.update(keyvalues) allvalues.update(insertion_values) @@ -916,7 +941,7 @@ class Database(object): Returns: None """ - allnames = [] + allnames = [] # type: List[str] allnames.extend(key_names) allnames.extend(value_names) @@ -1100,7 +1125,7 @@ class Database(object): keyvalues : dict of column names and values to select the rows with retcols : list of strings giving the names of the columns to return """ - results = [] + results = [] # type: List[Dict[str, Any]] if not iterable: return results @@ -1439,7 +1464,7 @@ class Database(object): raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") where_clause = "WHERE " if filters or keyvalues else "" - arg_list = [] + arg_list = [] # type: List[Any] if filters: where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) arg_list += list(filters.values()) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 9d2d519922..035f9ea6e9 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,29 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import importlib import platform -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} - -def create_engine(database_config): +def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] - engine_class = SUPPORTED_MODULE.get(name, None) - if engine_class: + if name == "sqlite3": + import sqlite3 + + return Sqlite3Engine(sqlite3, database_config) + + if name == "psycopg2": # pypy requires psycopg2cffi rather than psycopg2 - if name == "psycopg2" and platform.python_implementation() == "PyPy": - name = "psycopg2cffi" - module = importlib.import_module(name) - return engine_class(module, database_config) + if platform.python_implementation() == "PyPy": + import psycopg2cffi as psycopg2 # type: ignore + else: + import psycopg2 # type: ignore + + return PostgresEngine(psycopg2, database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) -__all__ = ["create_engine", "IncorrectDatabaseSetup"] +__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ec5a4d198b..ab0bbe4bd3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -12,7 +12,94 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc +from typing import Generic, TypeVar + +from synapse.storage.types import Connection class IncorrectDatabaseSetup(RuntimeError): pass + + +ConnectionType = TypeVar("ConnectionType", bound=Connection) + + +class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): + def __init__(self, module, database_config: dict): + self.module = module + + @property + @abc.abstractmethod + def single_threaded(self) -> bool: + ... + + @property + @abc.abstractmethod + def can_native_upsert(self) -> bool: + """ + Do we support native UPSERTs? + """ + ... + + @property + @abc.abstractmethod + def supports_tuple_comparison(self) -> bool: + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + ... + + @property + @abc.abstractmethod + def supports_using_any_list(self) -> bool: + """ + Do we support using `a = ANY(?)` and passing a list + """ + ... + + @abc.abstractmethod + def check_database( + self, db_conn: ConnectionType, allow_outdated_version: bool = False + ) -> None: + ... + + @abc.abstractmethod + def check_new_database(self, txn) -> None: + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + ... + + @abc.abstractmethod + def convert_param_style(self, sql: str) -> str: + ... + + @abc.abstractmethod + def on_new_connection(self, db_conn: ConnectionType) -> None: + ... + + @abc.abstractmethod + def is_deadlock(self, error: Exception) -> bool: + ... + + @abc.abstractmethod + def is_connection_closed(self, conn: ConnectionType) -> bool: + ... + + @abc.abstractmethod + def lock_table(self, txn, table: str) -> None: + ... + + @abc.abstractmethod + def get_next_state_group_id(self, txn) -> int: + """Returns an int that can be used as a new state_group ID + """ + ... + + @property + @abc.abstractmethod + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0' + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 53b3f372b0..6c7d08a6f2 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -15,16 +15,14 @@ import logging -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup logger = logging.getLogger(__name__) -class PostgresEngine(object): - single_threaded = False - +class PostgresEngine(BaseDatabaseEngine): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do @@ -36,6 +34,10 @@ class PostgresEngine(object): self.synchronous_commit = database_config.get("synchronous_commit", True) self._version = None # unknown as yet + @property + def single_threaded(self) -> bool: + return False + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 641e490697..2bfeefd54e 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import sqlite3 import struct import threading +from synapse.storage.engines import BaseDatabaseEngine -class Sqlite3Engine(object): - single_threaded = True +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in (None, ":memory:",) @@ -31,6 +31,10 @@ class Sqlite3Engine(object): self._current_state_group_id = None self._current_state_group_id_lock = threading.Lock() + @property + def single_threaded(self) -> bool: + return True + @property def can_native_upsert(self): """ @@ -68,7 +72,6 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database diff --git a/synapse/storage/types.py b/synapse/storage/types.py new file mode 100644 index 0000000000..daff81c5ee --- /dev/null +++ b/synapse/storage/types.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterable, Iterator, List, Tuple + +from typing_extensions import Protocol + + +""" +Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 +""" + + +class Cursor(Protocol): + def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + ... + + def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + ... + + def fetchall(self) -> List[Tuple]: + ... + + def fetchone(self) -> Tuple: + ... + + @property + def description(self) -> Any: + return None + + @property + def rowcount(self) -> int: + return 0 + + def __iter__(self) -> Iterator[Tuple]: + ... + + def close(self) -> None: + ... + + +class Connection(Protocol): + def cursor(self) -> Cursor: + ... + + def close(self) -> None: + ... + + def commit(self) -> None: + ... + + def rollback(self, *args, **kwargs) -> None: + ... diff --git a/tox.ini b/tox.ini index 4ccfde01b5..6521535137 100644 --- a/tox.ini +++ b/tox.ini @@ -168,7 +168,6 @@ commands= coverage html [testenv:mypy] -basepython = python3.7 skip_install = True deps = {[base]deps} @@ -179,7 +178,8 @@ env = extras = all commands = mypy \ synapse/api \ - synapse/config/ \ + synapse/appservice \ + synapse/config \ synapse/events/spamcheck.py \ synapse/federation/sender \ synapse/federation/transport \ @@ -192,6 +192,7 @@ commands = mypy \ synapse/rest \ synapse/spam_checker_api \ synapse/storage/engines \ + synapse/storage/database.py \ synapse/streams # To find all folders that pass mypy you run: -- cgit 1.4.1 From dc6fb56c5ffb41d907b7fd645a701c2d9684afc3 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 10 Mar 2020 14:40:28 +0000 Subject: Hopefully mypy is happy now --- synapse/logging/context.py | 3 ++- synapse/storage/database.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 56805120be..860b99a4c6 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -210,7 +210,7 @@ class LoggingContext(object): class Sentinel(object): """Sentinel to represent the root context""" - __slots__ = ["previous_context", "alive", "request", "scope"] + __slots__ = ["previous_context", "alive", "request", "scope", "tag"] def __init__(self) -> None: # Minimal set for compatibility with LoggingContext @@ -218,6 +218,7 @@ class LoggingContext(object): self.alive = None self.request = None self.scope = None + self.tag = None def __str__(self): return "sentinel" diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 609db40616..e61595336c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -29,7 +29,11 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig -from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.logging.context import ( + LoggingContext, + LoggingContextOrSentinel, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine @@ -543,7 +547,9 @@ class Database(object): Returns: Deferred: The result of func """ - parent_context = LoggingContext.current_context() + parent_context = ( + LoggingContext.current_context() + ) # type: Optional[LoggingContextOrSentinel] if parent_context == LoggingContext.sentinel: logger.warning( "Starting db connection from sentinel context: metrics will be lost" -- cgit 1.4.1 From 39230d217104f3cd7aba9065dc478f935ce1e614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Mar 2020 14:45:33 +0000 Subject: Clean up some LoggingContext stuff (#7120) * Pull Sentinel out of LoggingContext ... and drop a few unnecessary references to it * Factor out LoggingContext.current_context move `current_context` and `set_context` out to top-level functions. Mostly this means that I can more easily trace what's actually referring to LoggingContext, but I think it's generally neater. * move copy-to-parent into `stop` this really just makes `start` and `stop` more symetric. It also means that it behaves correctly if you manually `set_log_context` rather than using the context manager. * Replace `LoggingContext.alive` with `finished` Turn `alive` into `finished` and make it a bit better defined. --- changelog.d/7120.misc | 1 + docs/log_contexts.md | 5 +- synapse/crypto/keyring.py | 4 +- synapse/federation/federation_base.py | 4 +- synapse/handlers/sync.py | 4 +- synapse/http/request_metrics.py | 6 +- synapse/logging/_structured.py | 4 +- synapse/logging/context.py | 234 +++++++++++---------- synapse/logging/scopecontextmanager.py | 13 +- synapse/storage/data_stores/main/events_worker.py | 4 +- synapse/storage/database.py | 11 +- synapse/util/metrics.py | 4 +- synapse/util/patch_inline_callbacks.py | 36 ++-- tests/crypto/test_keyring.py | 7 +- .../federation/test_matrix_federation_agent.py | 6 +- tests/http/federation/test_srv_resolver.py | 6 +- tests/http/test_fedclient.py | 6 +- tests/rest/client/test_transactions.py | 16 +- tests/unittest.py | 12 +- tests/util/caches/test_descriptors.py | 22 +- tests/util/test_async_utils.py | 15 +- tests/util/test_linearizer.py | 6 +- tests/util/test_logcontext.py | 22 +- tests/utils.py | 6 +- 24 files changed, 232 insertions(+), 222 deletions(-) create mode 100644 changelog.d/7120.misc (limited to 'synapse/storage/database.py') diff --git a/changelog.d/7120.misc b/changelog.d/7120.misc new file mode 100644 index 0000000000..731f4dcb52 --- /dev/null +++ b/changelog.d/7120.misc @@ -0,0 +1 @@ +Clean up some LoggingContext code. diff --git a/docs/log_contexts.md b/docs/log_contexts.md index 5331e8c88b..fe30ca2791 100644 --- a/docs/log_contexts.md +++ b/docs/log_contexts.md @@ -29,14 +29,13 @@ from synapse.logging import context # omitted from future snippets def handle_request(request_id): request_context = context.LoggingContext() - calling_context = context.LoggingContext.current_context() - context.LoggingContext.set_current_context(request_context) + calling_context = context.set_current_context(request_context) try: request_context.request = request_id do_request_handling() logger.debug("finished") finally: - context.LoggingContext.set_current_context(calling_context) + context.set_current_context(calling_context) def do_request_handling(): logger.debug("phew") # this will be logged against request_id diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 983f0ead8c..a9f4025bfe 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -43,8 +43,8 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -236,7 +236,7 @@ class Keyring(object): """ try: - ctx = LoggingContext.current_context() + ctx = current_context() # map from server name to a set of outstanding request ids server_to_request_ids = {} diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index b0b0eba41e..4b115aac04 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -32,8 +32,8 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( - LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.types import JsonDict, get_domain_from_id @@ -78,7 +78,7 @@ class FederationBase(object): """ deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus) - ctx = LoggingContext.current_context() + ctx = current_context() def callback(_, pdu: EventBase): with PreserveLoggingContext(ctx): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 669dbc8a48..5746fdea14 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -26,7 +26,7 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import FilterCollection from synapse.events import EventBase -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter @@ -301,7 +301,7 @@ class SyncHandler(object): else: sync_type = "incremental_sync" - context = LoggingContext.current_context() + context = current_context() if context: context.tag = sync_type diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 58f9cc61c8..b58ae3d9db 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -19,7 +19,7 @@ import threading from prometheus_client.core import Counter, Histogram -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context from synapse.metrics import LaterGauge logger = logging.getLogger(__name__) @@ -148,7 +148,7 @@ LaterGauge( class RequestMetrics(object): def start(self, time_sec, name, method): self.start = time_sec - self.start_context = LoggingContext.current_context() + self.start_context = current_context() self.name = name self.method = method @@ -163,7 +163,7 @@ class RequestMetrics(object): with _in_flight_requests_lock: _in_flight_requests.discard(self) - context = LoggingContext.current_context() + context = current_context() tag = "" if context: diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index ffa7b20ca8..7372450b45 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -42,7 +42,7 @@ from synapse.logging._terse_json import ( TerseJSONToConsoleLogObserver, TerseJSONToTCPLogObserver, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context def stdlib_log_level_to_twisted(level: str) -> LogLevel: @@ -86,7 +86,7 @@ class LogContextObserver(object): ].startswith("Timing out client"): return - context = LoggingContext.current_context() + context = current_context() # Copy the context information to the log event. if context is not None: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 860b99a4c6..a8eafb1c7c 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -175,7 +175,54 @@ class ContextResourceUsage(object): return res -LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"] +LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] + + +class _Sentinel(object): + """Sentinel to represent the root context""" + + __slots__ = ["previous_context", "finished", "request", "scope", "tag"] + + def __init__(self) -> None: + # Minimal set for compatibility with LoggingContext + self.previous_context = None + self.finished = False + self.request = None + self.scope = None + self.tag = None + + def __str__(self): + return "sentinel" + + def copy_to(self, record): + pass + + def copy_to_twisted_log_entry(self, record): + record["request"] = None + record["scope"] = None + + def start(self): + pass + + def stop(self): + pass + + def add_database_transaction(self, duration_sec): + pass + + def add_database_scheduled(self, sched_sec): + pass + + def record_event_fetch(self, event_count): + pass + + def __nonzero__(self): + return False + + __bool__ = __nonzero__ # python3 + + +SENTINEL_CONTEXT = _Sentinel() class LoggingContext(object): @@ -199,76 +246,33 @@ class LoggingContext(object): "_resource_usage", "usage_start", "main_thread", - "alive", + "finished", "request", "tag", "scope", ] - thread_local = threading.local() - - class Sentinel(object): - """Sentinel to represent the root context""" - - __slots__ = ["previous_context", "alive", "request", "scope", "tag"] - - def __init__(self) -> None: - # Minimal set for compatibility with LoggingContext - self.previous_context = None - self.alive = None - self.request = None - self.scope = None - self.tag = None - - def __str__(self): - return "sentinel" - - def copy_to(self, record): - pass - - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - - def start(self): - pass - - def stop(self): - pass - - def add_database_transaction(self, duration_sec): - pass - - def add_database_scheduled(self, sched_sec): - pass - - def record_event_fetch(self, event_count): - pass - - def __nonzero__(self): - return False - - __bool__ = __nonzero__ # python3 - - sentinel = Sentinel() - def __init__(self, name=None, parent_context=None, request=None) -> None: - self.previous_context = LoggingContext.current_context() + self.previous_context = current_context() self.name = name # track the resources used by this context so far self._resource_usage = ContextResourceUsage() - # If alive has the thread resource usage when the logcontext last - # became active. + # The thread resource usage when the logcontext became active. None + # if the context is not currently active. self.usage_start = None self.main_thread = get_thread_id() self.request = None self.tag = "" - self.alive = True self.scope = None # type: Optional[_LogContextScope] + # keep track of whether we have hit the __exit__ block for this context + # (suggesting that the the thing that created the context thinks it should + # be finished, and that re-activating it would suggest an error). + self.finished = False + self.parent_context = parent_context if self.parent_context is not None: @@ -283,44 +287,15 @@ class LoggingContext(object): return str(self.request) return "%s@%x" % (self.name, id(self)) - @classmethod - def current_context(cls) -> LoggingContextOrSentinel: - """Get the current logging context from thread local storage - - Returns: - LoggingContext: the current logging context - """ - return getattr(cls.thread_local, "current_context", cls.sentinel) - - @classmethod - def set_current_context( - cls, context: LoggingContextOrSentinel - ) -> LoggingContextOrSentinel: - """Set the current logging context in thread local storage - Args: - context(LoggingContext): The context to activate. - Returns: - The context that was previously active - """ - current = cls.current_context() - - if current is not context: - current.stop() - cls.thread_local.current_context = context - context.start() - return current - def __enter__(self) -> "LoggingContext": """Enters this logging context into thread local storage""" - old_context = self.set_current_context(self) + old_context = set_current_context(self) if self.previous_context != old_context: logger.warning( "Expected previous context %r, found %r", self.previous_context, old_context, ) - self.alive = True - return self def __exit__(self, type, value, traceback) -> None: @@ -329,24 +304,19 @@ class LoggingContext(object): Returns: None to avoid suppressing any exceptions that were thrown. """ - current = self.set_current_context(self.previous_context) + current = set_current_context(self.previous_context) if current is not self: - if current is self.sentinel: + if current is SENTINEL_CONTEXT: logger.warning("Expected logging context %s was lost", self) else: logger.warning( "Expected logging context %s but found %s", self, current ) - self.alive = False - - # if we have a parent, pass our CPU usage stats on - if self.parent_context is not None and hasattr( - self.parent_context, "_resource_usage" - ): - self.parent_context._resource_usage += self._resource_usage - # reset them in case we get entered again - self._resource_usage.reset() + # the fact that we are here suggests that the caller thinks that everything + # is done and dusted for this logcontext, and further activity will not get + # recorded against the correct metrics. + self.finished = True def copy_to(self, record) -> None: """Copy logging fields from this context to a log record or @@ -371,9 +341,14 @@ class LoggingContext(object): logger.warning("Started logcontext %s on different thread", self) return + if self.finished: + logger.warning("Re-starting finished log context %s", self) + # If we haven't already started record the thread resource usage so # far - if not self.usage_start: + if self.usage_start: + logger.warning("Re-starting already-active log context %s", self) + else: self.usage_start = get_thread_resource_usage() def stop(self) -> None: @@ -396,6 +371,15 @@ class LoggingContext(object): self.usage_start = None + # if we have a parent, pass our CPU usage stats on + if self.parent_context is not None and hasattr( + self.parent_context, "_resource_usage" + ): + self.parent_context._resource_usage += self._resource_usage + + # reset them in case we get entered again + self._resource_usage.reset() + def get_resource_usage(self) -> ContextResourceUsage: """Get resources used by this logcontext so far. @@ -409,7 +393,7 @@ class LoggingContext(object): # If we are on the correct thread and we're currently running then we # can include resource usage so far. is_main_thread = get_thread_id() == self.main_thread - if self.alive and self.usage_start and is_main_thread: + if self.usage_start and is_main_thread: utime_delta, stime_delta = self._get_cputime() res.ru_utime += utime_delta res.ru_stime += stime_delta @@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter): Returns: True to include the record in the log output. """ - context = LoggingContext.current_context() + context = current_context() for key, value in self.defaults.items(): setattr(record, key, value) @@ -512,27 +496,24 @@ class PreserveLoggingContext(object): __slots__ = ["current_context", "new_context", "has_parent"] - def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None: - if new_context is None: - self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel - else: - self.new_context = new_context + def __init__( + self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT + ) -> None: + self.new_context = new_context def __enter__(self) -> None: """Captures the current logging context""" - self.current_context = LoggingContext.set_current_context(self.new_context) + self.current_context = set_current_context(self.new_context) if self.current_context: self.has_parent = self.current_context.previous_context is not None - if not self.current_context.alive: - logger.debug("Entering dead context: %s", self.current_context) def __exit__(self, type, value, traceback) -> None: """Restores the current logging context""" - context = LoggingContext.set_current_context(self.current_context) + context = set_current_context(self.current_context) if context != self.new_context: - if context is LoggingContext.sentinel: + if not context: logger.warning("Expected logging context %s was lost", self.new_context) else: logger.warning( @@ -541,9 +522,30 @@ class PreserveLoggingContext(object): context, ) - if self.current_context is not LoggingContext.sentinel: - if not self.current_context.alive: - logger.debug("Restoring dead context: %s", self.current_context) + +_thread_local = threading.local() +_thread_local.current_context = SENTINEL_CONTEXT + + +def current_context() -> LoggingContextOrSentinel: + """Get the current logging context from thread local storage""" + return getattr(_thread_local, "current_context", SENTINEL_CONTEXT) + + +def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel: + """Set the current logging context in thread local storage + Args: + context(LoggingContext): The context to activate. + Returns: + The context that was previously active + """ + current = current_context() + + if current is not context: + current.stop() + _thread_local.current_context = context + context.start() + return current def nested_logging_context( @@ -572,7 +574,7 @@ def nested_logging_context( if parent_context is not None: context = parent_context # type: LoggingContextOrSentinel else: - context = LoggingContext.current_context() + context = current_context() return LoggingContext( parent_context=context, request=str(context.request) + "-" + suffix ) @@ -604,7 +606,7 @@ def run_in_background(f, *args, **kwargs): CRITICAL error about an unhandled error will be logged without much indication about where it came from. """ - current = LoggingContext.current_context() + current = current_context() try: res = f(*args, **kwargs) except: # noqa: E722 @@ -625,7 +627,7 @@ def run_in_background(f, *args, **kwargs): # The function may have reset the context before returning, so # we need to restore it now. - ctx = LoggingContext.set_current_context(current) + ctx = set_current_context(current) # The original context will be restored when the deferred # completes, but there is nothing waiting for it, so it will @@ -674,7 +676,7 @@ def make_deferred_yieldable(deferred): # ok, we can't be sure that a yield won't block, so let's reset the # logcontext, and add a callback to the deferred to restore it. - prev_context = LoggingContext.set_current_context(LoggingContext.sentinel) + prev_context = set_current_context(SENTINEL_CONTEXT) deferred.addBoth(_set_context_cb, prev_context) return deferred @@ -684,7 +686,7 @@ ResultT = TypeVar("ResultT") def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: """A callback function which just sets the logging context""" - LoggingContext.set_current_context(context) + set_current_context(context) return result @@ -752,7 +754,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs): Deferred: A Deferred which fires a callback with the result of `f`, or an errback if `f` throws an exception. """ - logcontext = LoggingContext.current_context() + logcontext = current_context() def g(): with LoggingContext(parent_context=logcontext): diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index 4eed4f2338..dc3ab00cbb 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager import twisted -from synapse.logging.context import LoggingContext, nested_logging_context +from synapse.logging.context import current_context, nested_logging_context logger = logging.getLogger(__name__) @@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager): (Scope) : the Scope that is active, or None if not available. """ - ctx = LoggingContext.current_context() - if ctx is LoggingContext.sentinel: - return None - else: - return ctx.scope + ctx = current_context() + return ctx.scope def activate(self, span, finish_on_close): """ @@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager): """ enter_logcontext = False - ctx = LoggingContext.current_context() + ctx = current_context() - if ctx is LoggingContext.sentinel: + if not ctx: # We don't want this scope to affect. logger.error("Tried to activate scope outside of loggingcontext") return Scope(None, span) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ca237c6f12..3013f49d32 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -35,7 +35,7 @@ from synapse.api.room_versions import ( ) from synapse.events import make_event_from_dict from synapse.events.utils import prune_event -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database @@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_entry_map] if missing_events_ids: - log_ctx = LoggingContext.current_context() + log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) # Note that _get_events_from_db is also responsible for turning db rows diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e61595336c..715c0346dd 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import ( LoggingContext, LoggingContextOrSentinel, + current_context, make_deferred_yieldable, ) from synapse.metrics.background_process_metrics import run_as_background_process @@ -483,7 +484,7 @@ class Database(object): end = monotonic_time() duration = end - start - LoggingContext.current_context().add_database_transaction(duration) + current_context().add_database_transaction(duration) transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) @@ -510,7 +511,7 @@ class Database(object): after_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry] - if LoggingContext.current_context() == LoggingContext.sentinel: + if not current_context(): logger.warning("Starting db txn '%s' from sentinel context", desc) try: @@ -547,10 +548,8 @@ class Database(object): Returns: Deferred: The result of func """ - parent_context = ( - LoggingContext.current_context() - ) # type: Optional[LoggingContextOrSentinel] - if parent_context == LoggingContext.sentinel: + parent_context = current_context() # type: Optional[LoggingContextOrSentinel] + if not parent_context: logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 7b18455469..ec61e14423 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -21,7 +21,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.metrics import InFlightGauge logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ class Measure(object): raise RuntimeError("Measure() objects cannot be re-used") self.start = self.clock.time() - parent_context = LoggingContext.current_context() + parent_context = current_context() self._logging_context = LoggingContext( "Measure[%s]" % (self.name,), parent_context ) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 3925927f9f..fdff195771 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -32,7 +32,7 @@ def do_patch(): Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context global _already_patched @@ -43,35 +43,35 @@ def do_patch(): def new_inline_callbacks(f): @functools.wraps(f) def wrapped(*args, **kwargs): - start_context = LoggingContext.current_context() + start_context = current_context() changes = [] # type: List[str] orig = orig_inline_callbacks(_check_yield_points(f, changes)) try: res = orig(*args, **kwargs) except Exception: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s changed context from %s to %s on exception" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) raise if not isinstance(res, Deferred) or res.called: - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "Completed %s changed context from %s to %s" % ( f, start_context, - LoggingContext.current_context(), + current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -79,23 +79,23 @@ def do_patch(): raise Exception(err) return res - if LoggingContext.current_context() != LoggingContext.sentinel: + if current_context(): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % (f, LoggingContext.current_context(), start_context) + ) % (f, current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) def check_ctx(r): - if LoggingContext.current_context() != start_context: + if current_context() != start_context: for err in changes: print(err, file=sys.stderr) err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", f, start_context, - LoggingContext.current_context(), + current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]): function """ - from synapse.logging.context import LoggingContext + from synapse.logging.context import current_context @functools.wraps(f) def check_yield_points_inner(*args, **kwargs): @@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]): last_yield_line_no = gen.gi_frame.f_lineno result = None # type: Any while True: - expected_context = LoggingContext.current_context() + expected_context = current_context() try: isFailure = isinstance(result, Failure) @@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]): else: d = gen.send(result) except (StopIteration, defer._DefGen_Return) as e: - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens when the context is lost sometime *after* the # final yield and returning. E.g. we forgot to yield on a # function that returns a deferred. @@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( f.__qualname__, expected_context, - LoggingContext.current_context(), + current_context(), f.__code__.co_filename, last_yield_line_no, ) @@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]): # This happens if we yield on a deferred that doesn't follow # the log context rules without wrapping in a `make_deferred_yieldable`. # We raise here as this should never happen. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): err = ( "%s yielded with context %s rather than sentinel," " yielded on line %d in %s" % ( frame.f_code.co_name, - LoggingContext.current_context(), + current_context(), frame.f_lineno, frame.f_code.co_filename, ) @@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]): except Exception as e: result = Failure(e) - if LoggingContext.current_context() != expected_context: + if current_context() != expected_context: # This happens because the context is lost sometime *after* the # previous yield and *after* the current yield. E.g. the @@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]): % ( frame.f_code.co_name, expected_context, - LoggingContext.current_context(), + current_context(), last_yield_line_no, frame.f_lineno, frame.f_code.co_filename, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 34d5895f18..70c8e72303 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -34,6 +34,7 @@ from synapse.crypto.keyring import ( from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.storage.keys import FetchKeyResult @@ -83,9 +84,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) def check_context(self, _, expected): - self.assertEquals( - getattr(LoggingContext.current_context(), "request", None), expected - ) + self.assertEquals(getattr(current_context(), "request", None), expected) def test_verify_json_objects_for_server_awaits_previous_requests(self): key1 = signedjson.key.generate_signing_key(1) @@ -105,7 +104,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def get_perspectives(**kwargs): - self.assertEquals(LoggingContext.current_context().request, "11") + self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): yield persp_deferred return persp_resp diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index fdc1d918ff..562397cdda 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -38,7 +38,7 @@ from synapse.http.federation.well_known_resolver import ( WellKnownResolver, _cache_period_from_headers, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.util.caches.ttlcache import TTLCache from tests import unittest @@ -155,7 +155,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - _check_logcontext(LoggingContext.sentinel) + _check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d @@ -1197,7 +1197,7 @@ class TestCachePeriodFromHeaders(unittest.TestCase): def _check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index df034ab237..babc201643 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError from twisted.names import dns, error from synapse.http.federation.srv_resolver import SrvResolver -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests import unittest from tests.utils import MockClock @@ -54,12 +54,12 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # should have reset to the sentinel context - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) result = yield resolve_d # should have restored our context - self.assertIs(LoggingContext.current_context(), ctx) + self.assertIs(current_context(), ctx) return result diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index 2b01f40a42..fff4f0cbf4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -29,14 +29,14 @@ from synapse.http.matrixfederationclient import ( MatrixFederationHttpClient, MatrixFederationRequest, ) -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from tests.server import FakeTransport from tests.unittest import HomeserverTestCase def check_logcontext(context): - current = LoggingContext.current_context() + current = current_context() if current is not context: raise AssertionError("Expected logcontext %s but was %s" % (context, current)) @@ -64,7 +64,7 @@ class FederationClientTests(HomeserverTestCase): self.assertNoResult(fetch_d) # should have reset logcontext to the sentinel - check_logcontext(LoggingContext.sentinel) + check_logcontext(SENTINEL_CONTEXT) try: fetch_res = yield fetch_d diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index a3d7e3c046..171632e195 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -2,7 +2,7 @@ from mock import Mock, call from twisted.internet import defer, reactor -from synapse.logging.context import LoggingContext +from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache from synapse.util import Clock @@ -52,14 +52,14 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test(): with LoggingContext("c") as c1: res = yield self.cache.fetch_or_execute(self.mock_key, cb) - self.assertIs(LoggingContext.current_context(), c1) + self.assertIs(current_context(), c1) self.assertEqual(res, "yay") # run the test twice in parallel d = defer.gatherResults([test(), test()]) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) yield d - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) @defer.inlineCallbacks def test_does_not_cache_exceptions(self): @@ -81,11 +81,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_does_not_cache_failures(self): @@ -107,11 +107,11 @@ class HttpTransactionCacheTestCase(unittest.TestCase): yield self.cache.fetch_or_execute(self.mock_key, cb) except Exception as e: self.assertEqual(e.args[0], "boo") - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) res = yield self.cache.fetch_or_execute(self.mock_key, cb) self.assertEqual(res, self.mock_http_response) - self.assertIs(LoggingContext.current_context(), test_context) + self.assertIs(current_context(), test_context) @defer.inlineCallbacks def test_cleans_up(self): diff --git a/tests/unittest.py b/tests/unittest.py index 8816a4d152..439174dbfc 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -38,7 +38,11 @@ from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.logging.context import LoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + current_context, + set_current_context, +) from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter @@ -97,10 +101,10 @@ class TestCase(unittest.TestCase): def setUp(orig): # if we're not starting in the sentinel logcontext, then to be honest # all future bets are off. - if LoggingContext.current_context() is not LoggingContext.sentinel: + if current_context(): self.fail( "Test starting with non-sentinel logging context %s" - % (LoggingContext.current_context(),) + % (current_context(),) ) old_level = logging.getLogger().level @@ -122,7 +126,7 @@ class TestCase(unittest.TestCase): # force a GC to workaround problems with deferreds leaking logcontexts when # they are GCed (see the logcontext docs) gc.collect() - LoggingContext.set_current_context(LoggingContext.sentinel) + set_current_context(SENTINEL_CONTEXT) return ret diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 39e360fe24..4d2b9e0d64 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -22,8 +22,10 @@ from twisted.internet import defer, reactor from synapse.api.errors import SynapseError from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, ) from synapse.util.caches import descriptors @@ -194,7 +196,7 @@ class DescriptorTestCase(unittest.TestCase): with LoggingContext() as c1: c1.name = "c1" r = yield obj.fn(1) - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) return r def check_result(r): @@ -204,12 +206,12 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d1.addCallback(check_result) # and another d2 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) d2.addCallback(check_result) # let the lookup complete @@ -239,14 +241,14 @@ class DescriptorTestCase(unittest.TestCase): try: d = obj.fn(1) self.assertEqual( - LoggingContext.current_context(), LoggingContext.sentinel + current_context(), SENTINEL_CONTEXT, ) yield d self.fail("No exception thrown") except SynapseError: pass - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) # the cache should now be empty self.assertEqual(len(obj.fn.cache.cache), 0) @@ -255,7 +257,7 @@ class DescriptorTestCase(unittest.TestCase): # set off a deferred which will do a cache lookup d1 = do_lookup() - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) return d1 @@ -366,10 +368,10 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1", inlineCallbacks=True) def list_fn(self, args1, arg2): - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" # we want this to behave like an asynchronous function yield run_on_reactor() - assert LoggingContext.current_context().request == "c1" + assert current_context().request == "c1" return self.mock(args1, arg2) with LoggingContext() as c1: @@ -377,9 +379,9 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) - self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertEqual(current_context(), SENTINEL_CONTEXT) r = yield d1 - self.assertEqual(LoggingContext.current_context(), c1) + self.assertEqual(current_context(), c1) obj.mock.assert_called_once_with([10, 20], 2) self.assertEqual(r, {10: "fish", 20: "chips"}) obj.mock.reset_mock() diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index f60918069a..17fd86d02d 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -16,7 +16,12 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred from twisted.internet.task import Clock -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + SENTINEL_CONTEXT, + LoggingContext, + PreserveLoggingContext, + current_context, +) from synapse.util.async_helpers import timeout_deferred from tests.unittest import TestCase @@ -79,10 +84,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), + current_context(), context_one, "errback %s run in unexpected logcontext %s" - % (deferred_name, LoggingContext.current_context()), + % (deferred_name, current_context()), ) return res @@ -90,7 +95,7 @@ class TimeoutDeferredTest(TestCase): original_deferred.addErrback(errback, "orig") timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) timing_out_d.addErrback(errback, "timingout") self.clock.pump((1.0,)) @@ -99,4 +104,4 @@ class TimeoutDeferredTest(TestCase): blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, defer.TimeoutError) - self.assertIs(LoggingContext.current_context(), context_one) + self.assertIs(current_context(), context_one) diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 0ec8ef90ce..852ef23185 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -19,7 +19,7 @@ from six.moves import range from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError -from synapse.logging.context import LoggingContext +from synapse.logging.context import LoggingContext, current_context from synapse.util import Clock from synapse.util.async_helpers import Linearizer @@ -54,11 +54,11 @@ class LinearizerTestCase(unittest.TestCase): def func(i, sleep=False): with LoggingContext("func(%s)" % i) as lc: with (yield linearizer.queue("")): - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) if sleep: yield Clock(reactor).sleep(0) - self.assertEqual(LoggingContext.current_context(), lc) + self.assertEqual(current_context(), lc) func(0, sleep=True) for i in range(1, 100): diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 281b32c4b8..95301c013c 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -2,8 +2,10 @@ import twisted.python.failure from twisted.internet import defer, reactor from synapse.logging.context import ( + SENTINEL_CONTEXT, LoggingContext, PreserveLoggingContext, + current_context, make_deferred_yieldable, nested_logging_context, run_in_background, @@ -15,7 +17,7 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(LoggingContext.current_context().request, value) + self.assertEquals(current_context().request, value) def test_with_context(self): with LoggingContext() as context_one: @@ -41,7 +43,7 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def _test_run_in_background(self, function): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() callback_completed = [False] @@ -71,7 +73,7 @@ class LoggingContextTestCase(unittest.TestCase): # make sure that the context was reset before it got thrown back # into the reactor try: - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) d2.callback(None) except BaseException: d2.errback(twisted.python.failure.Failure()) @@ -108,7 +110,7 @@ class LoggingContextTestCase(unittest.TestCase): async def testfunc(): self._check_test_key("one") d = Clock(reactor).sleep(0) - self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one") @@ -129,14 +131,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) return d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -145,14 +147,14 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_chained_deferreds(self): - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 @@ -189,14 +191,14 @@ class LoggingContextTestCase(unittest.TestCase): reactor.callLater(0, d.callback, None) await d - sentinel_context = LoggingContext.current_context() + sentinel_context = current_context() with LoggingContext() as context_one: context_one.request = "one" d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable - self.assertIs(LoggingContext.current_context(), sentinel_context) + self.assertIs(current_context(), sentinel_context) yield d1 diff --git a/tests/utils.py b/tests/utils.py index 513f358f4f..968d109f77 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -493,10 +493,10 @@ class MockClock(object): return self.time() * 1000 def call_later(self, delay, callback, *args, **kwargs): - current_context = LoggingContext.current_context() + ctx = current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context + set_current_context(ctx) callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] -- cgit 1.4.1 From f31e65a749f84f8b3278c91784509d908d4fb342 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Apr 2020 23:06:39 +0100 Subject: bg update to clear out duplicate outbound_device_list_pokes (#7193) We seem to have some duplicates, which could do with being cleared out. --- changelog.d/7193.misc | 1 + synapse/storage/data_stores/main/client_ips.py | 16 ++--- synapse/storage/data_stores/main/devices.py | 73 ++++++++++++++++++- .../delta/58/02remove_dup_outbound_pokes.sql | 22 ++++++ synapse/storage/database.py | 83 +++++++++++++++++++++- tests/storage/test_database.py | 52 ++++++++++++++ 6 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 changelog.d/7193.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql create mode 100644 tests/storage/test_database.py (limited to 'synapse/storage/database.py') diff --git a/changelog.d/7193.misc b/changelog.d/7193.misc new file mode 100644 index 0000000000..383a738e64 --- /dev/null +++ b/changelog.d/7193.misc @@ -0,0 +1 @@ +Add a background database update job to clear out duplicate `device_lists_outbound_pokes`. diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index e1ccb27142..92bc06919b 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_tuple_comparison_clause from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -303,16 +303,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. - if self.database_engine.supports_tuple_comparison: - where_clause = "(user_id, device_id) > (?, ?)" - where_args = [last_user_id, last_device_id] - else: - # We explicitly do a `user_id >= ? AND (...)` here to ensure - # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)` - # makes it hard for query optimiser to tell that it can use the - # index on user_id - where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)" - where_args = [last_user_id, last_user_id, last_device_id] + where_clause, where_args = make_tuple_comparison_clause( + self.database_engine, + [("user_id", last_user_id), ("device_id", last_device_id)], + ) sql = """ SELECT diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 4c5bea4a5c..ee3a2ab031 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -32,7 +32,11 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database, LoggingTransaction +from synapse.storage.database import ( + Database, + LoggingTransaction, + make_tuple_comparison_clause, +) from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -49,6 +53,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( "drop_device_list_streams_non_unique_indexes" ) +BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" + class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -714,6 +720,11 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): self._drop_device_list_streams_non_unique_indexes, ) + # clear out duplicate device list outbound pokes + self.db.updates.register_background_update_handler( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + ) + @defer.inlineCallbacks def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): @@ -728,6 +739,66 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): ) return 1 + async def _remove_duplicate_outbound_pokes(self, progress, batch_size): + # for some reason, we have accumulated duplicate entries in + # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + # efficient. + # + # For each duplicate, we delete all the existing rows and put one back. + + KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] + last_row = progress.get( + "last_row", + {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, + ) + + def _txn(txn): + clause, args = make_tuple_comparison_clause( + self.db.engine, [(x, last_row[x]) for x in KEY_COLS] + ) + sql = """ + SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts + FROM device_lists_outbound_pokes + WHERE %s + GROUP BY %s + HAVING count(*) > 1 + ORDER BY %s + LIMIT ? + """ % ( + clause, # WHERE + ",".join(KEY_COLS), # GROUP BY + ",".join(KEY_COLS), # ORDER BY + ) + txn.execute(sql, args + [batch_size]) + rows = self.db.cursor_to_dict(txn) + + row = None + for row in rows: + self.db.simple_delete_txn( + txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + ) + + row["sent"] = False + self.db.simple_insert_txn( + txn, "device_lists_outbound_pokes", row, + ) + + if row: + self.db.updates._background_update_progress_txn( + txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + ) + + return len(rows) + + rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn) + + if not rows: + await self.db.updates._end_background_update( + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES + ) + + return rows + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql new file mode 100644 index 0000000000..fdc39e9ba5 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + /* for some reason, we have accumulated duplicate entries in + * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less + * efficient. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (5800, 'remove_dup_outbound_pokes', '{}'); diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 715c0346dd..a7cd97b0b0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -17,7 +17,17 @@ import logging import time from time import monotonic as monotonic_time -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, +) from six import iteritems, iterkeys, itervalues from six.moves import intern, range @@ -1557,3 +1567,74 @@ def make_in_list_sql_clause( return "%s = ANY(?)" % (column,), [list(iterable)] else: return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) + + +KV = TypeVar("KV") + + +def make_tuple_comparison_clause( + database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]] +) -> Tuple[str, List[KV]]: + """Returns a tuple comparison SQL clause + + Depending what the SQL engine supports, builds a SQL clause that looks like either + "(a, b) > (?, ?)", or "(a > ?) OR (a == ? AND b > ?)". + + Args: + database_engine + keys: A set of (column, value) pairs to be compared. + + Returns: + A tuple of SQL query and the args + """ + if database_engine.supports_tuple_comparison: + return ( + "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)), + [k[1] for k in keys], + ) + + # we want to build a clause + # (a > ?) OR + # (a == ? AND b > ?) OR + # (a == ? AND b == ? AND c > ?) + # ... + # (a == ? AND b == ? AND ... AND z > ?) + # + # or, equivalently: + # + # (a > ? OR (a == ? AND + # (b > ? OR (b == ? AND + # ... + # (y > ? OR (y == ? AND + # z > ? + # )) + # ... + # )) + # )) + # + # which itself is equivalent to (and apparently easier for the query optimiser): + # + # (a >= ? AND (a > ? OR + # (b >= ? AND (b > ? OR + # ... + # (y >= ? AND (y > ? OR + # z > ? + # )) + # ... + # )) + # )) + # + # + + clause = "" + args = [] # type: List[KV] + for k, v in keys[:-1]: + clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k) + args.extend([v, v]) + + (k, v) = keys[-1] + clause += "%s > ?" % (k,) + args.append(v) + + clause += "))" * (len(keys) - 1) + return clause, args diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py new file mode 100644 index 0000000000..5a77c84962 --- /dev/null +++ b/tests/storage/test_database.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.database import make_tuple_comparison_clause +from synapse.storage.engines import BaseDatabaseEngine + +from tests import unittest + + +def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: + # returns a DatabaseEngine, circumventing the abc mechanism + # any kwargs are set as attributes on the class before instantiating it + t = type( + "TestBaseDatabaseEngine", + (BaseDatabaseEngine,), + dict(BaseDatabaseEngine.__dict__), + ) + # defeat the abc mechanism + t.__abstractmethods__ = set() + for k, v in kwargs.items(): + setattr(t, k, v) + return t(None, None) + + +class TupleComparisonClauseTestCase(unittest.TestCase): + def test_native_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=True) + clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + self.assertEqual(clause, "(a,b) > (?,?)") + self.assertEqual(args, [1, 2]) + + def test_emulated_tuple_comparison(self): + db_engine = _stub_db_engine(supports_tuple_comparison=False) + clause, args = make_tuple_comparison_clause( + db_engine, [("a", 1), ("b", 2), ("c", 3)] + ) + self.assertEqual( + clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))" + ) + self.assertEqual(args, [1, 1, 2, 2, 3]) -- cgit 1.4.1 From 0f6ebf393d49a3ab8e0b723026ac58c6aea1d51d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 01:08:15 +0100 Subject: Better type annotations for simple_upsert_txn most of these params don't really need to be lists. --- synapse/storage/database.py | 73 ++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 30 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a7cd97b0b0..f66880cbba 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.types import Collection from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -889,20 +890,24 @@ class Database(object): txn.execute(sql, list(allvalues.values())) def simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): + self, + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[str]], + ) -> None: """ Upsert, many times. Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. """ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( @@ -914,20 +919,24 @@ class Database(object): ) def simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): + self, + txn: LoggingTransaction, + table: str, + key_names: Iterable[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[str]], + ) -> None: """ Upsert, many times, but without native UPSERT support or batching. Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. """ # No value columns, therefore make a blank list so that the following # zip() works correctly. @@ -941,20 +950,24 @@ class Database(object): self.simple_upsert_txn_emulated(txn, table, _keys, _vals) def simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): + self, + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + ) -> None: """ Upsert, many times, using batching where possible. Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. """ allnames = [] # type: List[str] allnames.extend(key_names) -- cgit 1.4.1 From e48361545de14be61a1a25096c7fb4b90828ed51 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 01:16:53 +0100 Subject: use an upsert to update device_lists_outbound_last_success --- changelog.d/7429.misc | 1 + synapse/storage/data_stores/main/devices.py | 60 +++++++++++++++------- ...vice_lists_outbound_last_success_unique_idx.sql | 28 ++++++++++ synapse/storage/database.py | 1 + 4 files changed, 72 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7429.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql (limited to 'synapse/storage/database.py') diff --git a/changelog.d/7429.misc b/changelog.d/7429.misc new file mode 100644 index 0000000000..3c25cd9917 --- /dev/null +++ b/changelog.d/7429.misc @@ -0,0 +1 @@ +Improve performance of `mark_as_sent_devices_by_remote`. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index ee3a2ab031..536cef3abd 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" +BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = ( + "drop_device_lists_outbound_last_success_non_unique_idx" +) + class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore): def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): # We update the device_lists_outbound_last_success with the successfully - # poked users. We do the join to see which users need to be inserted and - # which updated. + # poked users. sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + SELECT user_id, coalesce(max(o.stream_id), 0) FROM device_lists_outbound_pokes as o - LEFT JOIN device_lists_outbound_last_success as s - USING (destination, user_id) WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id """ txn.execute(sql, (destination, stream_id)) rows = txn.fetchall() - sql = """ - UPDATE device_lists_outbound_last_success - SET stream_id = ? - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2])) - - sql = """ - INSERT INTO device_lists_outbound_last_success - (destination, user_id, stream_id) VALUES (?, ?, ?) - """ - txn.executemany( - sql, ((destination, row[0], row[1]) for row in rows if not row[2]) + self.db.simple_upsert_many_txn( + txn=txn, + table="device_lists_outbound_last_success", + key_names=("destination", "user_id"), + key_values=((destination, user_id) for user_id, _ in rows), + value_names=("stream_id",), + value_values=((stream_id,) for _, stream_id in rows), ) # Delete all sent outbound pokes @@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, ) + # create a unique index on device_lists_outbound_last_success + self.db.updates.register_background_index_update( + "device_lists_outbound_last_success_unique_idx", + index_name="device_lists_outbound_last_success_unique_idx", + table="device_lists_outbound_last_success", + columns=["destination", "user_id"], + unique=True, + ) + + # once that completes, we can remove the old non-unique index. + self.db.updates.register_background_update_handler( + BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX, + self._drop_device_lists_outbound_last_success_non_unique_idx, + ) + @defer.inlineCallbacks def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): @@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def _drop_device_lists_outbound_last_success_non_unique_idx( + self, progress, batch_size + ): + def f(txn): + txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx") + + await self.db.runInteraction( + "drop_device_lists_outbound_last_success_non_unique_idx", f, + ) + await self.db.updates._end_background_update( + BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX + ) + return 1 + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql new file mode 100644 index 0000000000..d5e6deb878 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql @@ -0,0 +1,28 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will create a unique index on +-- device_lists_outbound_last_success +INSERT into background_updates (ordering, update_name, progress_json) + VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}'); + +-- once that completes, we can drop the old index. +INSERT into background_updates (ordering, update_name, progress_json, depends_on) + VALUES ( + 5804, + 'drop_device_lists_outbound_last_success_non_unique_idx', + '{}', + 'device_lists_outbound_last_success_unique_idx' + ); diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f66880cbba..2b635d6ca0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -79,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", "event_search": "event_search_event_id_idx", + "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx", } -- cgit 1.4.1 From 1a1da60ad2c9172fe487cd38a164b39df60f4cb5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 12 May 2020 11:20:48 +0100 Subject: Fix new flake8 errors (#7470) --- changelog.d/7470.misc | 1 + synapse/app/_base.py | 5 +++-- synapse/config/server.py | 2 +- synapse/notifier.py | 10 ++++++---- synapse/push/mailer.py | 7 +++++-- synapse/storage/database.py | 4 ++-- tests/config/test_load.py | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7470.misc (limited to 'synapse/storage/database.py') diff --git a/changelog.d/7470.misc b/changelog.d/7470.misc new file mode 100644 index 0000000000..45e66ecf48 --- /dev/null +++ b/changelog.d/7470.misc @@ -0,0 +1 @@ +Fix linting errors in new version of Flake8. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 628292b890..dedff81af3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,6 +22,7 @@ import sys import traceback from daemonize import Daemonize +from typing_extensions import NoReturn from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory @@ -139,9 +140,9 @@ def start_reactor( run() -def quit_with_error(error_string): +def quit_with_error(error_string: str) -> NoReturn: message_lines = error_string.split("\n") - line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 + line_length = max(len(line) for line in message_lines if len(line) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/config/server.py b/synapse/config/server.py index 6d88231843..ed28da3deb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -522,7 +522,7 @@ class ServerConfig(Config): ) def has_tls_listener(self) -> bool: - return any(l["tls"] for l in self.listeners) + return any(listener["tls"] for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs diff --git a/synapse/notifier.py b/synapse/notifier.py index 71d9ed62b0..87c120a59c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Callable, List +from typing import Callable, Iterable, List, TypeVar from prometheus_client import Counter @@ -42,12 +42,14 @@ users_woken_by_stream_counter = Counter( "synapse_notifier_users_woken_by_stream", "", ["stream"] ) +T = TypeVar("T") + # TODO(paul): Should be shared somewhere -def count(func, l): - """Return the number of items in l for which func returns true.""" +def count(func: Callable[[T], bool], it: Iterable[T]) -> int: + """Return the number of items in it for which func returns true.""" n = 0 - for x in l: + for x in it: if func(x): n += 1 return n diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 73580c1c6c..ab33abbeed 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,6 +19,7 @@ import logging import time from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Iterable, List, TypeVar from six.moves import urllib @@ -41,6 +42,8 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) +T = TypeVar("T") + MESSAGE_FROM_PERSON_IN_ROOM = ( "You have a message on %(app)s from %(person)s in the %(room)s room..." @@ -638,10 +641,10 @@ def safe_text(raw_text): ) -def deduped_ordered_list(l): +def deduped_ordered_list(it: Iterable[T]) -> List[T]: seen = set() ret = [] - for item in l: + for item in it: if item not in seen: seen.add(item) ret.append(item) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2b635d6ca0..c3d0863429 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -214,9 +214,9 @@ class LoggingTransaction: def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) - def _make_sql_one_line(self, sql): + def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + return " ".join(line.strip() for line in sql.splitlines() if line.strip()) def _do_execute(self, func, sql, *args): sql = self._make_sql_one_line(sql) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index b3e557bd6a..734a9983e8 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase): with open(self.file, "r") as f: contents = f.readlines() - contents = [l for l in contents if needle not in l] + contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: f.write("".join(contents)) -- cgit 1.4.1 From 08fa96f03037178620f5f0dd609fac52fbf7f2d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:07:24 +0100 Subject: Remove `exception_to_unicode` this is a no-op on python 3. --- synapse/storage/database.py | 15 +++------------ synapse/util/stringutils.py | 36 ------------------------------------ 2 files changed, 3 insertions(+), 48 deletions(-) (limited to 'synapse/storage/database.py') diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c3d0863429..9947dbce77 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor from synapse.types import Collection -from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -424,20 +423,14 @@ class Database(object): # This can happen if the database disappears mid # transaction. logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, + "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) + logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: @@ -449,9 +442,7 @@ class Database(object): conn.rollback() except self.engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, e1, ) continue raise diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 6899bcb788..2cfa5cf721 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -85,42 +85,6 @@ def to_ascii(s): return s -def exception_to_unicode(e): - """Helper function to extract the text of an exception as a unicode string - - Args: - e (Exception): exception to be stringified - - Returns: - unicode - """ - # urgh, this is a mess. The basic problem here is that psycopg2 constructs its - # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will - # then produce the raw byte sequence. Under Python 2, this will then cause another - # error if it gets mixed with a `unicode` object, as per - # https://github.com/matrix-org/synapse/issues/4252 - - # First of all, if we're under python3, everything is fine because it will sort this - # nonsense out for us. - if not PY2: - return str(e) - - # otherwise let's have a stab at decoding the exception message. We'll circumvent - # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii') - # and instead look at what is in the args member. - - if len(e.args) == 0: - return "" - elif len(e.args) > 1: - return six.text_type(repr(e.args)) - - msg = e.args[0] - if isinstance(msg, bytes): - return msg.decode("utf-8", errors="replace") - else: - return msg - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.4.1 From edd9a7214c467e96f5a694598b2fbcfae3ac2912 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 26 May 2020 11:43:17 +0100 Subject: Replace device_27_unique_idx bg update with a fg one (#7562) The bg update never managed to complete, because it kept being interrupted by transactions which want to take a lock. Just doing it in the foreground isn't that bad, and is a good deal simpler. --- UPGRADE.rst | 12 +++- changelog.d/7562.misc | 1 + synapse/storage/data_stores/main/devices.py | 34 ++------- ...vice_lists_outbound_last_success_unique_idx.sql | 28 -------- .../main/schema/delta/58/06dlols_unique_idx.py | 80 ++++++++++++++++++++++ synapse/storage/database.py | 1 - synapse/storage/prepare_database.py | 13 ++-- 7 files changed, 104 insertions(+), 65 deletions(-) create mode 100644 changelog.d/7562.misc delete mode 100644 synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py (limited to 'synapse/storage/database.py') diff --git a/UPGRADE.rst b/UPGRADE.rst index 41c47e964d..3b5627e852 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,9 +75,15 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb -Upgrading to v1.13.0 +Upgrading to v1.14.0 ==================== +This version includes a database update which is run as part of the upgrade, +and which may take a couple of minutes in the case of a large server. Synapse +will not respond to HTTP requests while this update is taking place. + +Upgrading to v1.13.0 +==================== Incorrect database migration in old synapse versions ---------------------------------------------------- @@ -136,12 +142,12 @@ back to v1.12.4 you need to: 2. Decrease the schema version in the database: .. code:: sql - + UPDATE schema_version SET version = 57; 3. Downgrade Synapse by following the instructions for your installation method in the "Rolling back to older versions" section above. - + Upgrading to v1.12.0 ==================== diff --git a/changelog.d/7562.misc b/changelog.d/7562.misc new file mode 100644 index 0000000000..3c25cd9917 --- /dev/null +++ b/changelog.d/7562.misc @@ -0,0 +1 @@ +Improve performance of `mark_as_sent_devices_by_remote`. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 417ac8dc7c..fb9f798e29 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -55,10 +55,6 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" -BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = ( - "drop_device_lists_outbound_last_success_non_unique_idx" -) - class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -749,19 +745,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, ) - # create a unique index on device_lists_outbound_last_success - self.db.updates.register_background_index_update( + # a pair of background updates that were added during the 1.14 release cycle, + # but replaced with 58/06dlols_unique_idx.py + self.db.updates.register_noop_background_update( "device_lists_outbound_last_success_unique_idx", - index_name="device_lists_outbound_last_success_unique_idx", - table="device_lists_outbound_last_success", - columns=["destination", "user_id"], - unique=True, ) - - # once that completes, we can remove the old non-unique index. - self.db.updates.register_background_update_handler( - BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX, - self._drop_device_lists_outbound_last_success_non_unique_idx, + self.db.updates.register_noop_background_update( + "drop_device_lists_outbound_last_success_non_unique_idx", ) @defer.inlineCallbacks @@ -838,20 +828,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def _drop_device_lists_outbound_last_success_non_unique_idx( - self, progress, batch_size - ): - def f(txn): - txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx") - - await self.db.runInteraction( - "drop_device_lists_outbound_last_success_non_unique_idx", f, - ) - await self.db.updates._end_background_update( - BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX - ) - return 1 - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql deleted file mode 100644 index d5e6deb878..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- register a background update which will create a unique index on --- device_lists_outbound_last_success -INSERT into background_updates (ordering, update_name, progress_json) - VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}'); - --- once that completes, we can drop the old index. -INSERT into background_updates (ordering, update_name, progress_json, depends_on) - VALUES ( - 5804, - 'drop_device_lists_outbound_last_success_non_unique_idx', - '{}', - 'device_lists_outbound_last_success_unique_idx' - ); diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py new file mode 100644 index 0000000000..d353f2bcb3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py @@ -0,0 +1,80 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This migration rebuilds the device_lists_outbound_last_success table without duplicate +entries, and with a UNIQUE index. +""" + +import logging +from io import StringIO + +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream +from synapse.storage.types import Cursor + +logger = logging.getLogger(__name__) + + +def run_upgrade(*args, **kwargs): + pass + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # some instances might already have this index, in which case we can skip this + if isinstance(database_engine, PostgresEngine): + cur.execute( + """ + SELECT 1 FROM pg_class WHERE relkind = 'i' + AND relname = 'device_lists_outbound_last_success_unique_idx' + """ + ) + + if cur.rowcount: + logger.info( + "Unique index exists on device_lists_outbound_last_success: " + "skipping rebuild" + ) + return + + logger.info("Rebuilding device_lists_outbound_last_success with unique index") + execute_statements_from_stream(cur, StringIO(_rebuild_commands)) + + +# there might be duplicates, so the easiest way to achieve this is to create a new +# table with the right data, and renaming it into place + +_rebuild_commands = """ +DROP TABLE IF EXISTS device_lists_outbound_last_success_new; + +CREATE TABLE device_lists_outbound_last_success_new ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +-- this took about 30 seconds on matrix.org's 16 million rows. +INSERT INTO device_lists_outbound_last_success_new + SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success + GROUP BY destination, user_id; + +-- and this another 30 seconds. +CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx + ON device_lists_outbound_last_success_new (destination, user_id); + +DROP TABLE device_lists_outbound_last_success; + +ALTER TABLE device_lists_outbound_last_success_new + RENAME TO device_lists_outbound_last_success; +""" diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 9947dbce77..b112ff3df2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -78,7 +78,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", "event_search": "event_search_event_id_idx", - "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx", } diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 640f242584..9afc145340 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -19,10 +19,12 @@ import logging import os import re from collections import Counter +from typing import TextIO import attr from synapse.storage.engines.postgres import PostgresEngine +from synapse.storage.types import Cursor logger = logging.getLogger(__name__) @@ -479,8 +481,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ) logger.info("applying schema %s for %s", name, modname) - for statement in get_statements(stream): - cur.execute(statement) + execute_statements_from_stream(cur, stream) # Mark as done. cur.execute( @@ -538,8 +539,12 @@ def get_statements(f): def executescript(txn, schema_path): with open(schema_path, "r") as f: - for statement in get_statements(f): - txn.execute(statement) + execute_statements_from_stream(txn, f) + + +def execute_statements_from_stream(cur: Cursor, f: TextIO): + for statement in get_statements(f): + cur.execute(statement) def _get_or_create_schema_state(txn, database_engine): -- cgit 1.4.1