From 8123b2f90945e6bb3d65cf48730c2133b46d2b9a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 May 2020 17:17:45 +0100 Subject: Add MultiWriterIdGenerator. (#7281) This will be used to coordinate stream IDs across multiple writers. Functions as the equivalent of both `StreamIdGenerator` and `SlavedIdTracker`. --- synapse/storage/util/id_generators.py | 169 +++++++++++++++++++++++++++++++++- 1 file changed, 167 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9d851beaa5..86d04ea9ac 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,6 +16,11 @@ import contextlib import threading from collections import deque +from typing import Dict, Set, Tuple + +from typing_extensions import Deque + +from synapse.storage.database import Database, LoggingTransaction class IdGenerator(object): @@ -87,7 +92,7 @@ class StreamIdGenerator(object): self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[int] def get_next(self): """ @@ -163,7 +168,7 @@ class ChainedIdGenerator(object): self.chained_generator = chained_generator self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] def get_next(self): """ @@ -198,3 +203,163 @@ class ChainedIdGenerator(object): return stream_id - 1, chained_id return self._current_max, self.chained_generator.get_current_token() + + +class MultiWriterIdGenerator: + """An ID generator that tracks a stream that can have multiple writers. + + Uses a Postgres sequence to coordinate ID assignment, but positions of other + writers will only get updated when `advance` is called (by replication). + + Note: Only works with Postgres. + + Args: + db_conn + db + instance_name: The name of this instance. + table: Database table associated with stream. + instance_column: Column that stores the row's writer's instance name + id_column: Column that stores the stream ID. + sequence_name: The name of the postgres sequence used to generate new + IDs. + """ + + def __init__( + self, + db_conn, + db: Database, + instance_name: str, + table: str, + instance_column: str, + id_column: str, + sequence_name: str, + ): + self._db = db + self._instance_name = instance_name + self._sequence_name = sequence_name + + # We lock as some functions may be called from DB threads. + self._lock = threading.Lock() + + self._current_positions = self._load_current_ids( + db_conn, table, instance_column, id_column + ) + + # Set of local IDs that we're still processing. The current position + # should be less than the minimum of this set (if not empty). + self._unfinished_ids = set() # type: Set[int] + + def _load_current_ids( + self, db_conn, table: str, instance_column: str, id_column: str + ) -> Dict[str, int]: + sql = """ + SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + GROUP BY %(instance)s + """ % { + "instance": instance_column, + "id": id_column, + "table": table, + } + + cur = db_conn.cursor() + cur.execute(sql) + + # `cur` is an iterable over returned rows, which are 2-tuples. + current_positions = dict(cur) + + cur.close() + + return current_positions + + def _load_next_id_txn(self, txn): + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + (next_id,) = txn.fetchone() + return next_id + + async def get_next(self): + """ + Usage: + with await stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) + + # Assert the fetched ID is actually greater than what we currently + # believe the ID to be. If not, then the sequence and table have got + # out of sync somehow. + assert self.get_current_token() < next_id + + with self._lock: + self._unfinished_ids.add(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_id + finally: + self._mark_id_as_finished(next_id) + + return manager() + + def get_next_txn(self, txn: LoggingTransaction): + """ + Usage: + + stream_id = stream_id_gen.get_next(txn) + # ... persist event ... + """ + + next_id = self._load_next_id_txn(txn) + + with self._lock: + self._unfinished_ids.add(next_id) + + txn.call_after(self._mark_id_as_finished, next_id) + txn.call_on_exception(self._mark_id_as_finished, next_id) + + return next_id + + def _mark_id_as_finished(self, next_id: int): + """The ID has finished being processed so we should advance the + current poistion if possible. + """ + + with self._lock: + self._unfinished_ids.discard(next_id) + + # Figure out if its safe to advance the position by checking there + # aren't any lower allocated IDs that are yet to finish. + if all(c > next_id for c in self._unfinished_ids): + curr = self._current_positions.get(self._instance_name, 0) + self._current_positions[self._instance_name] = max(curr, next_id) + + def get_current_token(self, instance_name: str = None) -> int: + """Gets the current position of a named writer (defaults to current + instance). + + Returns 0 if we don't have a position for the named writer (likely due + to it being a new writer). + """ + + if instance_name is None: + instance_name = self._instance_name + + with self._lock: + return self._current_positions.get(instance_name, 0) + + def get_positions(self) -> Dict[str, int]: + """Get a copy of the current positon map. + """ + + with self._lock: + return dict(self._current_positions) + + def advance(self, instance_name: str, new_id: int): + """Advance the postion of the named writer to the given ID, if greater + than existing entry. + """ + + with self._lock: + self._current_positions[instance_name] = max( + new_id, self._current_positions.get(instance_name, 0) + ) -- cgit 1.5.1 From 0f6ebf393d49a3ab8e0b723026ac58c6aea1d51d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 01:08:15 +0100 Subject: Better type annotations for simple_upsert_txn most of these params don't really need to be lists. --- synapse/storage/database.py | 73 ++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 30 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a7cd97b0b0..f66880cbba 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +from synapse.types import Collection from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -889,20 +890,24 @@ class Database(object): txn.execute(sql, list(allvalues.values())) def simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): + self, + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[str]], + ) -> None: """ Upsert, many times. Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. """ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: return self.simple_upsert_many_txn_native_upsert( @@ -914,20 +919,24 @@ class Database(object): ) def simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): + self, + txn: LoggingTransaction, + table: str, + key_names: Iterable[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[str]], + ) -> None: """ Upsert, many times, but without native UPSERT support or batching. Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. """ # No value columns, therefore make a blank list so that the following # zip() works correctly. @@ -941,20 +950,24 @@ class Database(object): self.simple_upsert_txn_emulated(txn, table, _keys, _vals) def simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): + self, + txn: LoggingTransaction, + table: str, + key_names: Collection[str], + key_values: Collection[Iterable[Any]], + value_names: Collection[str], + value_values: Iterable[Iterable[Any]], + ) -> None: """ Upsert, many times, using batching where possible. Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None + table: The table to upsert into + key_names: The key column names. + key_values: A list of each row's key column values. + value_names: The value column names + value_values: A list of each row's value column values. + Ignored if value_names is empty. """ allnames = [] # type: List[str] allnames.extend(key_names) -- cgit 1.5.1 From e48361545de14be61a1a25096c7fb4b90828ed51 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 01:16:53 +0100 Subject: use an upsert to update device_lists_outbound_last_success --- changelog.d/7429.misc | 1 + synapse/storage/data_stores/main/devices.py | 60 +++++++++++++++------- ...vice_lists_outbound_last_success_unique_idx.sql | 28 ++++++++++ synapse/storage/database.py | 1 + 4 files changed, 72 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7429.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql (limited to 'synapse/storage') diff --git a/changelog.d/7429.misc b/changelog.d/7429.misc new file mode 100644 index 0000000000..3c25cd9917 --- /dev/null +++ b/changelog.d/7429.misc @@ -0,0 +1 @@ +Improve performance of `mark_as_sent_devices_by_remote`. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index ee3a2ab031..536cef3abd 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" +BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = ( + "drop_device_lists_outbound_last_success_non_unique_idx" +) + class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore): def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): # We update the device_lists_outbound_last_success with the successfully - # poked users. We do the join to see which users need to be inserted and - # which updated. + # poked users. sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + SELECT user_id, coalesce(max(o.stream_id), 0) FROM device_lists_outbound_pokes as o - LEFT JOIN device_lists_outbound_last_success as s - USING (destination, user_id) WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id """ txn.execute(sql, (destination, stream_id)) rows = txn.fetchall() - sql = """ - UPDATE device_lists_outbound_last_success - SET stream_id = ? - WHERE destination = ? AND user_id = ? - """ - txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2])) - - sql = """ - INSERT INTO device_lists_outbound_last_success - (destination, user_id, stream_id) VALUES (?, ?, ?) - """ - txn.executemany( - sql, ((destination, row[0], row[1]) for row in rows if not row[2]) + self.db.simple_upsert_many_txn( + txn=txn, + table="device_lists_outbound_last_success", + key_names=("destination", "user_id"), + key_values=((destination, user_id) for user_id, _ in rows), + value_names=("stream_id",), + value_values=((stream_id,) for _, stream_id in rows), ) # Delete all sent outbound pokes @@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, ) + # create a unique index on device_lists_outbound_last_success + self.db.updates.register_background_index_update( + "device_lists_outbound_last_success_unique_idx", + index_name="device_lists_outbound_last_success_unique_idx", + table="device_lists_outbound_last_success", + columns=["destination", "user_id"], + unique=True, + ) + + # once that completes, we can remove the old non-unique index. + self.db.updates.register_background_update_handler( + BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX, + self._drop_device_lists_outbound_last_success_non_unique_idx, + ) + @defer.inlineCallbacks def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): def f(conn): @@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows + async def _drop_device_lists_outbound_last_success_non_unique_idx( + self, progress, batch_size + ): + def f(txn): + txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx") + + await self.db.runInteraction( + "drop_device_lists_outbound_last_success_non_unique_idx", f, + ) + await self.db.updates._end_background_update( + BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX + ) + return 1 + class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql new file mode 100644 index 0000000000..d5e6deb878 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql @@ -0,0 +1,28 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will create a unique index on +-- device_lists_outbound_last_success +INSERT into background_updates (ordering, update_name, progress_json) + VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}'); + +-- once that completes, we can drop the old index. +INSERT into background_updates (ordering, update_name, progress_json, depends_on) + VALUES ( + 5804, + 'drop_device_lists_outbound_last_success_non_unique_idx', + '{}', + 'device_lists_outbound_last_success_unique_idx' + ); diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f66880cbba..2b635d6ca0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -79,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", "event_search": "event_search_event_id_idx", + "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx", } -- cgit 1.5.1 From db5f9031b7ab1617bfd6dd6d06c3bd1229b31d57 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 00:03:38 +0100 Subject: Fix batching for fetching cross-signing keys There's no point carefully dividing a list into batches, and then completely ignoring the batches. --- .../storage/data_stores/main/end_to_end_keys.py | 38 +++++++++------------- 1 file changed, 15 insertions(+), 23 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index bcf746b7ef..d0b9742695 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -25,7 +25,9 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import make_in_list_sql_clause from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter class EndToEndKeyWorkerStore(SQLBaseStore): @@ -391,26 +393,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """ result = {} - batch_size = 100 - chunks = [ - user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) - ] - for user_chunk in chunks: - sql = """ + for user_chunk in batch_iter(user_ids, 100): + clause, params = make_in_list_sql_clause( + txn.database_engine, "k.user_id", user_chunk + ) + sql = ( + """ SELECT k.user_id, k.keytype, k.keydata, k.stream_id FROM e2e_cross_signing_keys k INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id FROM e2e_cross_signing_keys GROUP BY user_id, keytype) s USING (user_id, stream_id, keytype) - WHERE k.user_id IN (%s) - """ % ( - ",".join("?" for u in user_chunk), + WHERE + """ + + clause ) - query_params = [] - query_params.extend(user_chunk) - txn.execute(sql, query_params) + txn.execute(sql, params) rows = self.db.cursor_to_dict(txn) for row in rows: @@ -453,15 +453,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): device_id = k devices[(user_id, device_id)] = key_type - device_list = list(devices) - - # split into batches - batch_size = 100 - chunks = [ - device_list[i : i + batch_size] - for i in range(0, len(device_list), batch_size) - ] - for user_chunk in chunks: + for batch in batch_iter(devices.keys(), size=100): sql = """ SELECT target_user_id, target_device_id, key_id, signature FROM e2e_cross_signing_signatures @@ -469,11 +461,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): AND (%s) """ % ( " OR ".join( - "(target_user_id = ? AND target_device_id = ?)" for d in devices + "(target_user_id = ? AND target_device_id = ?)" for _ in batch ) ) query_params = [from_user_id] - for item in devices: + for item in batch: # item is a (user_id, device_id) tuple query_params.extend(item) -- cgit 1.5.1 From 16b67c404d232bd133f03adc5e5eba60610f941f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 6 May 2020 00:08:36 +0100 Subject: Make get_e2e_cross_signing_key delegate to get_e2e_cross_signing_keys_bulk ... mostly because the latter has a cache. --- changelog.d/7428.misc | 1 + .../storage/data_stores/main/end_to_end_keys.py | 60 +++------------------- 2 files changed, 7 insertions(+), 54 deletions(-) create mode 100644 changelog.d/7428.misc (limited to 'synapse/storage') diff --git a/changelog.d/7428.misc b/changelog.d/7428.misc new file mode 100644 index 0000000000..db5ff76ded --- /dev/null +++ b/changelog.d/7428.misc @@ -0,0 +1 @@ +Improve performance of `get_e2e_cross_signing_key`. diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index d0b9742695..20698bfd16 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -270,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) - def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): - """Returns a user's cross-signing key. - - Args: - txn (twisted.enterprise.adbapi.Connection): db connection - user_id (str): the user whose key is being requested - key_type (str): the type of key that is being requested: either 'master' - for a master key, 'self_signing' for a self-signing key, or - 'user_signing' for a user-signing key - from_user_id (str): if specified, signatures made by this user on - the key will be included in the result - - Returns: - dict of the key data or None if not found - """ - sql = ( - "SELECT keydata " - " FROM e2e_cross_signing_keys " - " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" - ) - txn.execute(sql, (user_id, key_type)) - row = txn.fetchone() - if not row: - return None - key = json.loads(row[0]) - - device_id = None - for k in key["keys"].values(): - device_id = k - - if from_user_id is not None: - sql = ( - "SELECT key_id, signature " - " FROM e2e_cross_signing_signatures " - " WHERE user_id = ? " - " AND target_user_id = ? " - " AND target_device_id = ? " - ) - txn.execute(sql, (from_user_id, user_id, device_id)) - row = txn.fetchone() - if row: - key.setdefault("signatures", {}).setdefault(from_user_id, {})[ - row[0] - ] = row[1] - - return key - + @defer.inlineCallbacks def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): """Returns a user's cross-signing key. @@ -331,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: dict of the key data or None if not found """ - return self.db.runInteraction( - "get_e2e_cross_signing_key", - self._get_e2e_cross_signing_key_txn, - user_id, - key_type, - from_user_id, - ) + res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id) + user_keys = res.get(user_id) + if not user_keys: + return None + return user_keys.get(key_type) @cached(num_args=1) def _get_bare_e2e_cross_signing_keys(self, user_id): -- cgit 1.5.1 From d7983b63a6746d92225295f1e9d521f847cf8ba7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 May 2020 13:51:08 +0100 Subject: Support any process writing to cache invalidation stream. (#7436) --- changelog.d/7436.misc | 1 + docs/tcp_replication.md | 4 - scripts/synapse_port_db | 4 +- synapse/replication/slave/storage/_base.py | 50 +++---------- synapse/replication/slave/storage/account_data.py | 6 +- synapse/replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 6 +- synapse/replication/slave/storage/events.py | 6 +- synapse/replication/slave/storage/groups.py | 6 +- synapse/replication/slave/storage/presence.py | 6 +- synapse/replication/slave/storage/push_rule.py | 6 +- synapse/replication/slave/storage/pushers.py | 6 +- synapse/replication/slave/storage/receipts.py | 6 +- synapse/replication/slave/storage/room.py | 4 +- synapse/replication/tcp/client.py | 6 +- synapse/replication/tcp/commands.py | 33 -------- synapse/replication/tcp/handler.py | 42 ++--------- synapse/replication/tcp/resource.py | 22 +++++- synapse/replication/tcp/streams/_base.py | 87 +++++++++++++++------- synapse/replication/tcp/streams/events.py | 4 +- synapse/replication/tcp/streams/federation.py | 12 ++- synapse/storage/_base.py | 3 + synapse/storage/data_stores/main/__init__.py | 15 +++- synapse/storage/data_stores/main/cache.py | 84 +++++++++++---------- .../schema/delta/58/05cache_instance.sql.postgres | 30 ++++++++ synapse/storage/prepare_database.py | 2 + 26 files changed, 226 insertions(+), 231 deletions(-) create mode 100644 changelog.d/7436.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres (limited to 'synapse/storage') diff --git a/changelog.d/7436.misc b/changelog.d/7436.misc new file mode 100644 index 0000000000..f7c4514950 --- /dev/null +++ b/changelog.d/7436.misc @@ -0,0 +1 @@ +Support any process writing to cache invalidation stream. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ab2fffbfe4..db318baa9d 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -219,10 +219,6 @@ Asks the server for the current position of all streams. Inform the server a pusher should be removed -#### INVALIDATE_CACHE (C) - - Inform the server a cache should be invalidated - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e8b698f3ff..acd9ac4b75 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [ "presence_stream", "push_rules_stream", "ex_outlier_stream", - "cache_invalidation_stream", + "cache_invalidation_stream_by_instance", "public_room_list_stream", "state_group_edges", "stream_ordering_to_exterm", @@ -188,7 +188,7 @@ class MockHomeserver: self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - self.version_string = "Synapse/"+get_version_string(synapse) + self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): return self.clock diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 5d7c8871a4..2904bd0235 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,14 +18,10 @@ from typing import Optional import six -from synapse.storage.data_stores.main.cache import ( - CURRENT_STATE_CACHE_NAME, - CacheInvalidationWorkerStore, -) +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine - -from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) @@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id" - ) # type: Optional[SlavedIdTracker] + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name=hs.get_instance_name(), + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None self.hs = hs - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "caches": - if self._cache_id_gen: - self._cache_id_gen.advance(token) - for row in rows: - if row.cache_func == CURRENT_STATE_CACHE_NAME: - if row.keys is None: - raise Exception( - "Can't send an 'invalidate all' for current state cache" - ) - - room_id = row.keys[0] - members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) - else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - txn.call_after(cache_func.invalidate, keys) - txn.call_after(self._send_invalidation_poke, cache_func, keys) - - def _send_invalidation_poke(self, cache_func, keys): - self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 65e54b1c71..2a4f5c7cfd 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) for row in rows: @@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedAccountDataStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index c923751e50..6e7fd259d4 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) for row in rows: @@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - return super(SlavedDeviceInboxStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 58fb0eaae3..9d8067342f 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) self._invalidate_caches_for_devices(token, rows) @@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedDeviceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_caches_for_devices(self, token, rows): for row in rows: diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 15011259df..b313720a4b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,7 +93,7 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) for row in rows: @@ -111,9 +111,7 @@ class SlavedEventStore( row.relates_to, backfilled=True, ) - return super(SlavedEventStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _process_event_stream_row(self, token, row): data = row.data diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 01bcf0e882..1851e7d525 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedGroupServerStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index fae3125072..bd79ba99be 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) - return super(SlavedPresenceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 6138796da4..5d5816d7eb 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedPushRuleStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 67be337945..cb78b49acb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) - return super(SlavedPusherStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 993432edcb..be716cc558 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "receipts": self._receipts_id_gen.advance(token) for row in rows: @@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - return super(SlavedReceiptsStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 10dda8708f..8873bf37e5 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows(stream_name, token, rows) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3bbf3c3569..20cb8a654f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -100,10 +100,10 @@ class ReplicationDataHandler: token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, instance_name, token, rows) - async def on_position(self, stream_name: str, token: int): - self.store.process_replication_rows(stream_name, token, []) + async def on_position(self, stream_name: str, instance_name: str, token: int): + self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f58e384d17..c04f622816 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -341,37 +341,6 @@ class RemovePusherCommand(Command): return " ".join((self.app_id, self.push_key, self.user_id)) -class InvalidateCacheCommand(Command): - """Sent by the client to invalidate an upstream cache. - - THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE - NOT DISASTROUS IF WE DROP ON THE FLOOR. - - Mainly used to invalidate destination retry timing caches. - - Format:: - - INVALIDATE_CACHE - - Where is a json list. - """ - - NAME = "INVALIDATE_CACHE" - - def __init__(self, cache_func, keys): - self.cache_func = cache_func - self.keys = keys - - @classmethod - def from_line(cls, line): - cache_func, keys_json = line.split(" ", 1) - - return cls(cache_func, json.loads(keys_json)) - - def to_line(self): - return " ".join((self.cache_func, _json_encoder.encode(self.keys))) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client. @@ -439,7 +408,6 @@ _COMMANDS = ( UserSyncCommand, FederationAckCommand, RemovePusherCommand, - InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, @@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = ( ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, - InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b14a3d9fca..7c5d6c76e7 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,18 +15,7 @@ # limitations under the License. import logging -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, -) +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar from prometheus_client import Counter @@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, - InvalidateCacheCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -171,7 +159,7 @@ class ReplicationCommandHandler: return for stream_name, stream in self._streams.items(): - current_token = stream.current_token() + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand(stream_name, self._instance_name, current_token) ) @@ -210,18 +198,6 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE( - self, conn: AbstractConnection, cmd: InvalidateCacheCommand - ): - invalidate_cache_counter.inc() - - if self._is_master: - # We invalidate the cache locally, but then also stream that to other - # workers. - await self._store.invalidate_cache_and_stream( - cmd.cache_func, tuple(cmd.keys) - ) - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() @@ -295,7 +271,7 @@ class ReplicationCommandHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) + logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) @@ -326,7 +302,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. - current_token = stream.current_token() + current_token = stream.current_token(cmd.instance_name) # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -363,7 +339,9 @@ class ReplicationCommandHandler: logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(stream_name, cmd.token) + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) @@ -491,12 +469,6 @@ class ReplicationCommandHandler: cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) - def send_invalidate_cache(self, cache_func: Callable, keys: tuple): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - def send_user_ip( self, user_id: str, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b690abedad..002171ce7c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol -from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.replication.tcp.streams import ( + STREAMS_MAP, + CachesStream, + FederationStream, + Stream, +) from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -71,11 +76,16 @@ class ReplicationStreamer(object): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() self._replication_torture_level = hs.config.replication_torture_level # Work out list of streams that this instance is the source of. self.streams = [] # type: List[Stream] + + # All workers can write to the cache invalidation stream. + self.streams.append(CachesStream(hs)) + if hs.config.worker_app is None: for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: @@ -83,6 +93,10 @@ class ReplicationStreamer(object): # has been disabled on the master. continue + if stream == CachesStream: + # We've already added it above. + continue + self.streams.append(stream(hs)) self.streams_by_name = {stream.NAME: stream for stream in self.streams} @@ -145,7 +159,9 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.current_token(): + if stream.last_token == stream.current_token( + self._instance_name + ): continue if self._replication_torture_level: @@ -157,7 +173,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.current_token(), + stream.current_token(self._instance_name), ) try: updates, current_token, limited = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 084604e8b0..b48a6a3e91 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,20 +95,25 @@ class Stream(object): def __init__( self, local_instance_name: str, - current_token_function: Callable[[], Token], + current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - current_token_function and update_function are callbacks which should be - implemented by subclasses. + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. - current_token_function is called to get the current token of the underlying - stream. It is only meaningful on the process that is the source of the - replication stream (ie, usually the master). + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). - update_function is called to get updates for this stream between a pair of - stream tokens. See the UpdateFunction type definition for more info. + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. Args: local_instance_name: The instance name of the current process @@ -120,13 +125,13 @@ class Stream(object): self.update_function = update_function # The token from which we last asked for updates - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or @@ -138,7 +143,7 @@ class Stream(object): position in stream, and `limited` is whether there are more updates to fetch. """ - current_token = self.current_token() + current_token = self.current_token(self.local_instance_name) updates, current_token, limited = await self.get_updates_since( self.local_instance_name, self.last_token, current_token ) @@ -170,6 +175,16 @@ class Stream(object): return updates, upto_token, limited +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + def db_query_to_update_function( query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: @@ -235,7 +250,7 @@ class BackfillStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_backfill_token, + current_token_without_instance(store.get_current_backfill_token), db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -271,7 +286,9 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), store.get_current_presence_token, update_function + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, ) @@ -296,7 +313,9 @@ class TypingStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), typing_handler.get_current_token, update_function + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, ) @@ -319,7 +338,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_receipt_stream_id, + current_token_without_instance(store.get_max_receipt_stream_id), db_query_to_update_function(store.get_all_updated_receipts), ) @@ -339,7 +358,7 @@ class PushRulesStream(Stream): hs.get_instance_name(), self._current_token, self._update_function ) - def _current_token(self) -> int: + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token @@ -373,7 +392,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - store.get_pushers_stream_token, + current_token_without_instance(store.get_pushers_stream_token), db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -402,12 +421,26 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, - db_query_to_update_function(store.get_all_updated_caches), + self.store.get_cache_stream_token, + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited class PublicRoomsStream(Stream): @@ -431,7 +464,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_public_room_stream_id, + current_token_without_instance(store.get_current_public_room_stream_id), db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -452,7 +485,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -470,7 +503,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_to_device_stream_token, + current_token_without_instance(store.get_to_device_stream_token), db_query_to_update_function(store.get_all_new_device_messages), ) @@ -490,7 +523,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_account_data_stream_id, + current_token_without_instance(store.get_max_account_data_stream_id), db_query_to_update_function(store.get_all_updated_tags), ) @@ -510,7 +543,7 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_max_account_data_stream_id, + current_token_without_instance(self.store.get_max_account_data_stream_id), db_query_to_update_function(self._update_function), ) @@ -541,7 +574,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_group_stream_token, + current_token_without_instance(store.get_group_stream_token), db_query_to_update_function(store.get_all_groups_changes), ) @@ -559,7 +592,7 @@ class UserSignatureStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function( store.get_all_user_signature_changes_for_remotes ), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 890e75d827..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -20,7 +20,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -119,7 +119,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store.get_current_events_token, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index b0505b8a2c..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,11 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, make_http_update_function +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, +) class FederationStream(Stream): @@ -41,7 +45,9 @@ class FederationStream(Stream): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = federation_sender.get_current_token + current_token = current_token_without_instance( + federation_sender.get_current_token + ) update_function = federation_sender.get_replication_rows elif hs.should_send_federation(): @@ -58,7 +64,7 @@ class FederationStream(Stream): super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - def _stub_current_token(): + def _stub_current_token(instance_name: str) -> int: # dummy current-token method for use on workers return 0 diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 13de5f1f62..59073c0a42 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta): self.db = database self.rand = random.SystemRandom() + def process_replication_rows(self, stream_name, instance_name, token, rows): + pass + def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index ceba10882c..cd2a1f0461 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, IdGenerator, + MultiWriterIdGenerator, StreamIdGenerator, ) from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .cache import CacheInvalidationStore +from .cache import CacheInvalidationWorkerStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -112,8 +113,8 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, - CacheInvalidationStore, UIAuthStore, + CacheInvalidationWorkerStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs @@ -170,8 +171,14 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id" + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name="master", + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", ) else: self._cache_id_gen = None diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 4dc5da3fe8..342a87a46b 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,11 +16,10 @@ import itertools import logging -from typing import Any, Iterable, Optional, Tuple - -from twisted.internet import defer +from typing import Any, Iterable, Optional from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def get_all_updated_caches(self, last_id, current_id, limit): + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ): + """Fetches cache invalidation rows between the two given IDs written + by the given instance. Returns at most `limit` rows. + """ + if last_id == current_id: - return defer.succeed([]) + return [] def get_all_updated_caches_txn(txn): # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) return txn.fetchall() - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_updated_caches", get_all_updated_caches_txn ) + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "caches": + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) -class CacheInvalidationStore(CacheInvalidationWorkerStore): - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - cache_func = getattr(self, cache_name, None) - if not cache_func: - return - - cache_func.invalidate(keys) - await self.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, - cache_func.__name__, - keys, - ) + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves @@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore): # the transaction. However, we want to only get an ID when we want # to use it, here, so we need to call __enter__ manually, and have # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) + stream_id = self._cache_id_gen.get_next_txn(txn) txn.call_after(self.hs.get_notifier().on_new_replication_data) if keys is not None: @@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore): self.db.simple_insert_txn( txn, - table="cache_invalidation_stream", + table="cache_invalidation_stream_by_instance", values={ "stream_id": stream_id, + "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, "invalidation_ts": self.clock.time_msec(), }, ) - def get_cache_stream_token(self): + def get_cache_stream_token(self, instance_name): if self._cache_id_gen: - return self._cache_id_gen.get_current_token() + return self._cache_id_gen.get_current_token(instance_name) else: return 0 diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres new file mode 100644 index 0000000000..aa46eb0e10 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We keep the old table here to enable us to roll back. It doesn't matter +-- that we have dropped all the data here. +TRUNCATE cache_invalidation_stream; + +CREATE TABLE cache_invalidation_stream_by_instance ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + cache_func TEXT NOT NULL, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id); + +CREATE SEQUENCE cache_invalidation_stream_seq; diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1712932f31..640f242584 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,6 +29,8 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. +# XXX: If you're about to bump this to 59 (or higher) please create an update +# that drops the unused `cache_invalidation_stream` table, as per #7436! SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) -- cgit 1.5.1 From a4a5ec4096f8de938f4a6e4264aeaaa0e0b26463 Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Thu, 7 May 2020 21:33:07 +0200 Subject: Add room details admin endpoint (#7317) --- changelog.d/7317.feature | 1 + docs/admin_api/rooms.md | 54 ++++++++++++++++++++++++++++++++ synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 26 ++++++++++++++- synapse/storage/data_stores/main/room.py | 31 ++++++++++++++++++ tests/rest/admin/test_room.py | 41 ++++++++++++++++++++++++ tests/storage/test_room.py | 11 +++++++ 7 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7317.feature (limited to 'synapse/storage') diff --git a/changelog.d/7317.feature b/changelog.d/7317.feature new file mode 100644 index 0000000000..23c063f280 --- /dev/null +++ b/changelog.d/7317.feature @@ -0,0 +1 @@ +Add room details admin endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 26fe8b8679..624e7745ba 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -264,3 +264,57 @@ Response: Once the `next_token` parameter is no longer present, we know we've reached the end of the list. + +# DRAFT: Room Details API + +The Room Details admin API allows server admins to get all details of a room. + +This API is still a draft and details might change! + +The following fields are possible in the JSON response body: + +* `room_id` - The ID of the room. +* `name` - The name of the room. +* `canonical_alias` - The canonical (main) alias address of the room. +* `joined_members` - How many users are currently in the room. +* `joined_local_members` - How many local users are currently in the room. +* `version` - The version of the room as a string. +* `creator` - The `user_id` of the room creator. +* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. +* `federatable` - Whether users on other servers can join this room. +* `public` - Whether the room is visible in room directory. +* `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"]. +* `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. +* `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. +* `state_events` - Total number of state_events of a room. Complexity of the room. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms/ + +{} +``` + +Response: + +``` +{ + "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", + "name": "Music Theory", + "canonical_alias": "#musictheory:matrix.org", + "joined_members": 127 + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 +} +``` diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ed70d448a1..6b85148a32 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -32,6 +32,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( JoinRoomAliasServlet, ListRoomRestServlet, + RoomRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -193,6 +194,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index d1bdb64111..7d40001988 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -26,6 +26,7 @@ from synapse.http.servlet import ( ) from synapse.rest.admin._base import ( admin_patterns, + assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, ) @@ -169,7 +170,7 @@ class ListRoomRestServlet(RestServlet): in a dictionary containing room information. Supports pagination. """ - PATTERNS = admin_patterns("/rooms") + PATTERNS = admin_patterns("/rooms$") def __init__(self, hs): self.store = hs.get_datastore() @@ -253,6 +254,29 @@ class ListRoomRestServlet(RestServlet): return 200, response +class RoomRestServlet(RestServlet): + """Get room details. + + TODO: Add on_POST to allow room creation without joining the room + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room_with_stats(room_id) + if not ret: + raise NotFoundError("Room not found") + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 147eba1df7..cafa664c16 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -98,6 +98,37 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) + def get_room_with_stats(self, room_id: str): + """Retrieve room with statistics. + + Args: + room_id: The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + + def get_room_with_stats_txn(txn, room_id): + sql = """ + SELECT room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, + rooms.creator, state.encryption, state.is_federatable AS federatable, + rooms.is_public AS public, state.join_rules, state.guest_access, + state.history_visibility, curr.current_state_events AS state_events + FROM rooms + LEFT JOIN room_stats_state state USING (room_id) + LEFT JOIN room_stats_current curr USING (room_id) + WHERE room_id = ? + """ + txn.execute(sql, [room_id]) + res = self.db.cursor_to_dict(txn)[0] + res["federatable"] = bool(res["federatable"]) + res["public"] = bool(res["public"]) + return res + + return self.db.runInteraction( + "get_room_with_stats", get_room_with_stats_txn, room_id + ) + def get_public_room_ids(self): return self.db.simple_select_onecol( table="rooms", diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 249c93722f..54cd24bf64 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -701,6 +701,47 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "bar") _search_test(None, "", expected_http_code=400) + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, channel.json_body["room_id"]) + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 086adeb8fd..3b78d48896 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats(self): + self.assertDictContainsSubset( + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "public": True, + }, + (yield self.store.get_room_with_stats(self.room.to_string())), + ) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks -- cgit 1.5.1 From a8580c5f1986e8d2f4bf9cf64190ac5e3d84aaf6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 11 May 2020 16:55:57 +0100 Subject: Remove unused store method get_hosts_in_room (#7448) --- changelog.d/7448.misc | 1 + synapse/storage/data_stores/main/roommember.py | 10 ---------- 2 files changed, 1 insertion(+), 10 deletions(-) create mode 100644 changelog.d/7448.misc (limited to 'synapse/storage') diff --git a/changelog.d/7448.misc b/changelog.d/7448.misc new file mode 100644 index 0000000000..2163cc9b97 --- /dev/null +++ b/changelog.d/7448.misc @@ -0,0 +1 @@ +Remove storage method `get_hosts_in_room` that is no longer called anywhere. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5bd0cb5cf..ae9a76713c 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -153,16 +153,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._check_safe_current_state_events_membership_updated_txn, ) - @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) - def get_hosts_in_room(self, room_id, cache_context): - """Returns the set of all hosts currently in the room - """ - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - return hosts - @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): return self.db.runInteraction( -- cgit 1.5.1 From 7cb8b4bc67042a39bd1b0e05df46089a2fce1955 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 12 May 2020 03:45:23 +1000 Subject: Allow configuration of Synapse's cache without using synctl or environment variables (#6391) --- changelog.d/6391.feature | 1 + docs/sample_config.yaml | 43 +++++- synapse/api/auth.py | 4 +- synapse/app/homeserver.py | 5 +- synapse/config/cache.py | 164 ++++++++++++++++++++++ synapse/config/database.py | 6 - synapse/config/homeserver.py | 2 + synapse/http/client.py | 6 +- synapse/metrics/_exposition.py | 12 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_rule_evaluator.py | 4 +- synapse/replication/slave/storage/client_ips.py | 3 +- synapse/state/__init__.py | 4 +- synapse/storage/data_stores/main/client_ips.py | 3 +- synapse/storage/data_stores/main/events_worker.py | 5 +- synapse/storage/data_stores/state/store.py | 6 +- synapse/util/caches/__init__.py | 144 ++++++++++--------- synapse/util/caches/descriptors.py | 36 ++++- synapse/util/caches/expiringcache.py | 29 +++- synapse/util/caches/lrucache.py | 52 +++++-- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/stream_change_cache.py | 33 ++++- synapse/util/caches/ttlcache.py | 2 +- tests/config/test_cache.py | 127 +++++++++++++++++ tests/storage/test__base.py | 8 +- tests/storage/test_appservice.py | 10 +- tests/storage/test_base.py | 3 +- tests/test_metrics.py | 34 +++++ tests/util/test_expiring_cache.py | 2 +- tests/util/test_lrucache.py | 6 +- tests/util/test_stream_change_cache.py | 5 +- tests/utils.py | 1 + 32 files changed, 620 insertions(+), 146 deletions(-) create mode 100644 changelog.d/6391.feature create mode 100644 synapse/config/cache.py create mode 100644 tests/config/test_cache.py (limited to 'synapse/storage') diff --git a/changelog.d/6391.feature b/changelog.d/6391.feature new file mode 100644 index 0000000000..f123426e23 --- /dev/null +++ b/changelog.d/6391.feature @@ -0,0 +1 @@ +Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 5abeaf519b..8a8415b9a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -603,6 +603,45 @@ acme: +## Caching ## + +# Caching can be configured through the following options. +# +# A cache 'factor' is a multiplier that can be applied to each of +# Synapse's caches in order to increase or decrease the maximum +# number of entries that can be stored. + +# The number of events to cache in memory. Not affected by +# caches.global_factor. +# +#event_cache_size: 10K + +caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + ## Database ## # The 'database' setting defines the database that synapse uses to store all of @@ -646,10 +685,6 @@ database: args: database: DATADIR/homeserver.db -# Number of events to cache in memory. -# -#event_cache_size: 10K - ## Logging ## diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1ad5ff9410..e009b1a760 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap, UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -73,7 +73,7 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + self.token_cache = LruCache(10000) register_cache("cache", "token_cache", self.token_cache) self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bc8695d8dd..d7f337e586 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module @@ -516,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size # # Performance statistics diff --git a/synapse/config/cache.py b/synapse/config/cache.py new file mode 100644 index 0000000000..91036a012e --- /dev/null +++ b/synapse/config/cache.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Dict + +from ._base import Config, ConfigError + +# The prefix for all cache factor-related environment variables +_CACHES = {} +_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" +_DEFAULT_FACTOR_SIZE = 0.5 +_DEFAULT_EVENT_CACHE_SIZE = "10K" + + +class CacheProperties(object): + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None + + +properties = CacheProperties() + + +def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): + """Register a cache that's size can dynamically change + + Args: + cache_name: A reference to the cache + cache_resize_callback: A callback function that will be ran whenever + the cache needs to be resized + """ + _CACHES[cache_name.lower()] = cache_resize_callback + + # Ensure all loaded caches are sized appropriately + # + # This method should only run once the config has been read, + # as it uses values read from it + if properties.resize_all_caches_func: + properties.resize_all_caches_func() + + +class CacheConfig(Config): + section = "caches" + _environ = os.environ + + @staticmethod + def reset(): + """Resets the caches to their defaults. Used for tests.""" + properties.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + properties.resize_all_caches_func = None + _CACHES.clear() + + def generate_config_section(self, **kwargs): + return """\ + ## Caching ## + + # Caching can be configured through the following options. + # + # A cache 'factor' is a multiplier that can be applied to each of + # Synapse's caches in order to increase or decrease the maximum + # number of entries that can be stored. + + # The number of events to cache in memory. Not affected by + # caches.global_factor. + # + #event_cache_size: 10K + + caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + """ + + def read_config(self, config, **kwargs): + self.event_cache_size = self.parse_size( + config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) + ) + self.cache_factors = {} # type: Dict[str, float] + + cache_config = config.get("caches") or {} + self.global_factor = cache_config.get( + "global_factor", properties.default_factor_size + ) + if not isinstance(self.global_factor, (int, float)): + raise ConfigError("caches.global_factor must be a number.") + + # Set the global one so that it's reflected in new caches + properties.default_factor_size = self.global_factor + + # Load cache factors from the config + individual_factors = cache_config.get("per_cache_factors") or {} + if not isinstance(individual_factors, dict): + raise ConfigError("caches.per_cache_factors must be a dictionary") + + # Override factors from environment if necessary + individual_factors.update( + { + key[len(_CACHE_PREFIX) + 1 :].lower(): float(val) + for key, val in self._environ.items() + if key.startswith(_CACHE_PREFIX + "_") + } + ) + + for cache, factor in individual_factors.items(): + if not isinstance(factor, (int, float)): + raise ConfigError( + "caches.per_cache_factors.%s must be a number" % (cache.lower(),) + ) + self.cache_factors[cache.lower()] = factor + + # Resize all caches (if necessary) with the new factors we've loaded + self.resize_all_caches() + + # Store this function so that it can be called from other classes without + # needing an instance of Config + properties.resize_all_caches_func = self.resize_all_caches + + def resize_all_caches(self): + """Ensure all cache sizes are up to date + + For each cache, run the mapped callback function with either + a specific cache factor or the default, global one. + """ + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/database.py b/synapse/config/database.py index 5b662d1b01..1064c2697b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -68,10 +68,6 @@ database: name: sqlite3 args: database: %(database_path)s - -# Number of events to cache in memory. -# -#event_cache_size: 10K """ @@ -116,8 +112,6 @@ class DatabaseConfig(Config): self.databases = [] def read_config(self, config, **kwargs): - self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - # We *experimentally* support specifying multiple databases via the # `databases` key. This is a map from a label to database config in the # same format as the `database` config option, plus an extra diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 996d3e6bf7..2c7b3a699f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig from .consent_config import ConsentConfig @@ -55,6 +56,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + CacheConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, diff --git a/synapse/http/client.py b/synapse/http/client.py index 58eb47c69c..3cef747a4d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred -from synapse.util.caches import CACHE_SIZE_FACTOR logger = logging.getLogger(__name__) @@ -241,7 +240,10 @@ class SimpleHttpClient(object): # tends to do so in batches, so we need to allow the pool to keep # lots of idle connections around. pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + # XXX: The justification for using the cache factor here is that larger instances + # will need both more cache and more connections. + # Still, this should probably be a separate dial + pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 self.agent = ProxyAgent( diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index a248103191..ab7f948ed4 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -33,6 +33,8 @@ from prometheus_client import REGISTRY from twisted.web.resource import Resource +from synapse.util import caches + try: from prometheus_client.samples import Sample except ImportError: @@ -103,13 +105,15 @@ def nameify_sample(sample): def generate_latest(registry, emit_help=False): - output = [] - for metric in registry.collect(): + # Trigger the cache metrics to be rescraped, which updates the common + # metrics but do not produce metrics themselves + for collector in caches.collectors_by_name.values(): + collector.collect() - if metric.name.startswith("__unused"): - continue + output = [] + for metric in registry.collect(): if not metric.samples: # No samples, don't bother. continue diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 433ca2f416..e75d964ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache( "cache", "push_rules_delta_state_cache_metric", cache=[], # Meaningless size, as this isn't a cache that stores values + resizable=False, ) @@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", - cache=[], # Meaningless size, as this isn't a cache that stores values + cache=[], # Meaningless size, as this isn't a cache that stores values, + resizable=False, ) @defer.inlineCallbacks diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4cd702b5fa..11032491af 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -22,7 +22,7 @@ from six import string_types from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object): # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +regex_cache = LruCache(50000) register_cache("cache", "regex_push_cache", regex_cache) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index fbf996e33a..1a38f53dfb 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,6 @@ from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.database import Database -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore @@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4afefc6b1d..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -35,7 +35,6 @@ from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -53,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -447,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 92bc06919b..71f8d43a76 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database, make_tuple_comparison_clause -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) super(ClientIpStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 73df6b33ba..b8c1bbdf99 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -75,7 +75,10 @@ class EventsWorkerStore(SQLBaseStore): super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 57a5267663..f3ad1e4369 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.types import StateMap -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), + 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), + "*stateGroupMembersCache*", 500000, ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index da5077b471..4b8a0c7a8f 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,27 +15,17 @@ # limitations under the License. import logging -import os -from typing import Dict +from typing import Callable, Dict, Optional import six from six.moves import intern -from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily - -logger = logging.getLogger(__name__) - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) +import attr +from prometheus_client.core import Gauge +from synapse.config.cache import add_resizable_cache -def get_cache_factor_for(cache_name): - env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() - factor = os.environ.get(env_var) - if factor: - return float(factor) - - return CACHE_SIZE_FACTOR - +logger = logging.getLogger(__name__) caches_by_name = {} collectors_by_name = {} # type: Dict @@ -44,6 +34,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -53,67 +44,82 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache, collect_callback=None): - """Register a cache object for metric collection. +@attr.s +class CacheMetric(object): + + _cache = attr.ib() + _cache_type = attr.ib(type=str) + _cache_name = attr.ib(type=str) + _collect_callback = attr.ib(type=Optional[Callable]) + + hits = attr.ib(default=0) + misses = attr.ib(default=0) + evicted_size = attr.ib(default=0) + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + try: + if self._cache_type == "response_cache": + response_cache_size.labels(self._cache_name).set(len(self._cache)) + response_cache_hits.labels(self._cache_name).set(self.hits) + response_cache_evicted.labels(self._cache_name).set(self.evicted_size) + response_cache_total.labels(self._cache_name).set( + self.hits + self.misses + ) + else: + cache_size.labels(self._cache_name).set(len(self._cache)) + cache_hits.labels(self._cache_name).set(self.hits) + cache_evicted.labels(self._cache_name).set(self.evicted_size) + cache_total.labels(self._cache_name).set(self.hits + self.misses) + if getattr(self._cache, "max_size", None): + cache_max_size.labels(self._cache_name).set(self._cache.max_size) + if self._collect_callback: + self._collect_callback() + except Exception as e: + logger.warning("Error calculating metrics for %s: %s", self._cache_name, e) + raise + + +def register_cache( + cache_type: str, + cache_name: str, + cache, + collect_callback: Optional[Callable] = None, + resizable: bool = True, + resize_callback: Optional[Callable] = None, +) -> CacheMetric: + """Register a cache object for metric collection and resizing. Args: - cache_type (str): - cache_name (str): name of the cache - cache (object): cache itself - collect_callback (callable|None): if not None, a function which is called during - metric collection to update additional metrics. + cache_type + cache_name: name of the cache + cache: cache itself + collect_callback: If given, a function which is called during metric + collection to update additional metrics. + resizable: Whether this cache supports being resized. + resize_callback: A function which can be called to resize the cache. Returns: CacheMetric: an object which provides inc_{hits,misses,evictions} methods """ + if resizable: + if not resize_callback: + resize_callback = getattr(cache, "set_cache_factor") + add_resizable_cache(cache_name, resize_callback) - # Check if the metric is already registered. Unregister it, if so. - # This usually happens during tests, as at runtime these caches are - # effectively singletons. + metric = CacheMetric(cache, cache_type, cache_name, collect_callback) metric_name = "cache_%s_%s" % (cache_type, cache_name) - if metric_name in collectors_by_name.keys(): - REGISTRY.unregister(collectors_by_name[metric_name]) - - class CacheMetric(object): - - hits = 0 - misses = 0 - evicted_size = 0 - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def inc_evictions(self, size=1): - self.evicted_size += size - - def describe(self): - return [] - - def collect(self): - try: - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) - if collect_callback: - collect_callback() - except Exception as e: - logger.warning("Error calculating metrics for %s: %s", cache_name, e) - raise - - yield GaugeMetricFamily("__unused", "") - - metric = CacheMetric() - REGISTRY.register(metric) caches_by_name[cache_name] = cache collectors_by_name[metric_name] = metric return metric diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2e8f6543e5..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import inspect import logging @@ -30,7 +31,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -81,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -89,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -99,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -111,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cddf1ed515..2726b67b6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -18,6 +18,7 @@ from collections import OrderedDict from six import iteritems, itervalues +from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -51,15 +52,16 @@ class ExpiringCache(object): an item on access. Defaults to False. iterable (bool): If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. - """ self._cache_name = cache_name + self._original_max_size = max_len + + self._max_size = int(max_len * cache_config.properties.default_factor_size) + self._clock = clock - self._max_len = max_len self._expiry_ms = expiry_ms - self._reset_expiry_on_get = reset_expiry_on_get self._cache = OrderedDict() @@ -82,9 +84,11 @@ class ExpiringCache(object): def __setitem__(self, key, value): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + self.evict() + def evict(self): # Evict if there are now too many items - while self._max_len and len(self) > self._max_len: + while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: self.metrics.inc_evictions(len(value.value)) @@ -170,6 +174,23 @@ class ExpiringCache(object): else: return len(self._cache) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self._max_size: + self._max_size = new_size + self.evict() + return True + return False + class _CacheEntry(object): __slots__ = ["time", "value"] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1536cb64f3..29fabac3cd 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading from functools import wraps +from typing import Callable, Optional, Type, Union +from synapse.config import cache as cache_config from synapse.util.caches.treecache import TreeCache @@ -52,17 +53,18 @@ class LruCache(object): def __init__( self, - max_size, - keylen=1, - cache_type=dict, - size_callback=None, - evicted_callback=None, + max_size: int, + keylen: int = 1, + cache_type: Type[Union[dict, TreeCache]] = dict, + size_callback: Optional[Callable] = None, + evicted_callback: Optional[Callable] = None, + apply_cache_factor_from_config: bool = True, ): """ Args: - max_size (int): + max_size: The maximum amount of entries the cache can hold - keylen (int): + keylen: The length of the tuple used as the cache key cache_type (type): type of underlying cache to be used. Typically one of dict @@ -73,9 +75,23 @@ class LruCache(object): evicted_callback (func(int)|None): if not None, called on eviction with the size of the evicted entry + + apply_cache_factor_from_config (bool): If true, `max_size` will be + multiplied by a cache factor derived from the homeserver config """ cache = cache_type() self.cache = cache # Used for introspection. + + # Save the original max size, and apply the default size factor. + self._original_max_size = max_size + # We previously didn't apply the cache factor here, and as such some caches were + # not affected by the global cache factor. Add an option here to disable applying + # the cache factor when a cache is created + if apply_cache_factor_from_config: + self.max_size = int(max_size * cache_config.properties.default_factor_size) + else: + self.max_size = int(max_size) + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -83,7 +99,7 @@ class LruCache(object): lock = threading.Lock() def evict(): - while cache_len() > max_size: + while cache_len() > self.max_size: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) @@ -236,6 +252,7 @@ class LruCache(object): return key in cache self.sentinel = object() + self._on_resize = evict self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -266,3 +283,20 @@ class LruCache(object): def __contains__(self, key): return self.contains(key) + + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self.max_size: + self.max_size = new_size + self._on_resize() + return True + return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index b68f9fe0d4..a6c60888e5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,7 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache("response_cache", name, self) + self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self): return len(self.pending_result_cache) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e54f80d76e..2a161bf244 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types @@ -46,7 +47,8 @@ class StreamChangeCache: max_size=10000, prefilled_cache: Optional[Mapping[EntityType, int]] = None, ): - self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) + self._original_max_size = max_size + self._max_size = math.floor(max_size) self._entity_to_key = {} # type: Dict[EntityType, int] # map from stream id to the a set of entities which changed at that stream id. @@ -58,12 +60,31 @@ class StreamChangeCache: # self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = caches.register_cache("cache", self.name, self._cache) + self.metrics = caches.register_cache( + "cache", self.name, self._cache, resize_callback=self.set_cache_factor + ) if prefilled_cache: for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = math.floor(self._original_max_size * factor) + if new_size != self._max_size: + self.max_size = new_size + self._evict() + return True + return False + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ @@ -171,6 +192,7 @@ class StreamChangeCache: e1 = self._cache[stream_pos] = set() e1.add(entity) self._entity_to_key[entity] = stream_pos + self._evict() # if the cache is too big, remove entries while len(self._cache) > self._max_size: @@ -179,6 +201,13 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] + def _evict(self): + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + self._entity_to_key.pop(entity, None) + def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 99646c7cf0..6437aa907e 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -38,7 +38,7 @@ class TTLCache(object): self._timer = timer - self._metrics = register_cache("ttl", cache_name, self) + self._metrics = register_cache("ttl", cache_name, self, resizable=False) def set(self, key, value, ttl): """Add/update an entry in the cache diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..2920279125 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. + """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. + """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 6857933540..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. """ @@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and diff --git a/tests/utils.py b/tests/utils.py index f9be62b499..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: -- cgit 1.5.1 From 1a1da60ad2c9172fe487cd38a164b39df60f4cb5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 12 May 2020 11:20:48 +0100 Subject: Fix new flake8 errors (#7470) --- changelog.d/7470.misc | 1 + synapse/app/_base.py | 5 +++-- synapse/config/server.py | 2 +- synapse/notifier.py | 10 ++++++---- synapse/push/mailer.py | 7 +++++-- synapse/storage/database.py | 4 ++-- tests/config/test_load.py | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7470.misc (limited to 'synapse/storage') diff --git a/changelog.d/7470.misc b/changelog.d/7470.misc new file mode 100644 index 0000000000..45e66ecf48 --- /dev/null +++ b/changelog.d/7470.misc @@ -0,0 +1 @@ +Fix linting errors in new version of Flake8. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 628292b890..dedff81af3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,6 +22,7 @@ import sys import traceback from daemonize import Daemonize +from typing_extensions import NoReturn from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory @@ -139,9 +140,9 @@ def start_reactor( run() -def quit_with_error(error_string): +def quit_with_error(error_string: str) -> NoReturn: message_lines = error_string.split("\n") - line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 + line_length = max(len(line) for line in message_lines if len(line) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/config/server.py b/synapse/config/server.py index 6d88231843..ed28da3deb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -522,7 +522,7 @@ class ServerConfig(Config): ) def has_tls_listener(self) -> bool: - return any(l["tls"] for l in self.listeners) + return any(listener["tls"] for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs diff --git a/synapse/notifier.py b/synapse/notifier.py index 71d9ed62b0..87c120a59c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Callable, List +from typing import Callable, Iterable, List, TypeVar from prometheus_client import Counter @@ -42,12 +42,14 @@ users_woken_by_stream_counter = Counter( "synapse_notifier_users_woken_by_stream", "", ["stream"] ) +T = TypeVar("T") + # TODO(paul): Should be shared somewhere -def count(func, l): - """Return the number of items in l for which func returns true.""" +def count(func: Callable[[T], bool], it: Iterable[T]) -> int: + """Return the number of items in it for which func returns true.""" n = 0 - for x in l: + for x in it: if func(x): n += 1 return n diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 73580c1c6c..ab33abbeed 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,6 +19,7 @@ import logging import time from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Iterable, List, TypeVar from six.moves import urllib @@ -41,6 +42,8 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) +T = TypeVar("T") + MESSAGE_FROM_PERSON_IN_ROOM = ( "You have a message on %(app)s from %(person)s in the %(room)s room..." @@ -638,10 +641,10 @@ def safe_text(raw_text): ) -def deduped_ordered_list(l): +def deduped_ordered_list(it: Iterable[T]) -> List[T]: seen = set() ret = [] - for item in l: + for item in it: if item not in seen: seen.add(item) ret.append(item) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2b635d6ca0..c3d0863429 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -214,9 +214,9 @@ class LoggingTransaction: def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) - def _make_sql_one_line(self, sql): + def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + return " ".join(line.strip() for line in sql.splitlines() if line.strip()) def _do_execute(self, func, sql, *args): sql = self._make_sql_one_line(sql) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index b3e557bd6a..734a9983e8 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase): with open(self.file, "r") as f: contents = f.readlines() - contents = [l for l in contents if needle not in l] + contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: f.write("".join(contents)) -- cgit 1.5.1 From 782e4e64df58dd9e114e7e1e467ae55219336b79 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 13:38:22 +0100 Subject: Shuffle persist event data store functions. (#7440) The aim here is to get to a stage where we have a `PersistEventStore` that holds all the write methods used during event persistence, so that we can take that class out of the `DataStore` mixin and instansiate it separately. This will allow us to instansiate it on processes other than master, while also ensuring it is only available on processes that are configured to write to events stream. This is a bit of an architectural change, where we end up with multiple classes per data store (rather than one per data store we have now). We end up having: 1. Storage classes that provide high level APIs that can talk to multiple data stores. 2. Data store modules that consist of classes that must point at the same database instance. 3. Classes in a data store that can be instantiated on processes depending on config. --- changelog.d/7440.misc | 1 + synapse/storage/data_stores/__init__.py | 9 + synapse/storage/data_stores/main/__init__.py | 8 +- synapse/storage/data_stores/main/censor_events.py | 213 ++++ .../storage/data_stores/main/event_federation.py | 83 -- .../storage/data_stores/main/event_push_actions.py | 74 -- synapse/storage/data_stores/main/events.py | 1216 ++++++++------------ synapse/storage/data_stores/main/events_worker.py | 152 ++- synapse/storage/data_stores/main/metrics.py | 128 +++ synapse/storage/data_stores/main/purge_events.py | 399 +++++++ synapse/storage/data_stores/main/rejections.py | 11 - synapse/storage/data_stores/main/relations.py | 60 +- synapse/storage/data_stores/main/room.py | 49 - synapse/storage/data_stores/main/roommember.py | 90 -- synapse/storage/data_stores/main/search.py | 23 - synapse/storage/data_stores/main/signatures.py | 30 - synapse/storage/data_stores/main/state.py | 35 - synapse/storage/persist_events.py | 27 +- tests/storage/test_event_push_actions.py | 3 +- 19 files changed, 1376 insertions(+), 1235 deletions(-) create mode 100644 changelog.d/7440.misc create mode 100644 synapse/storage/data_stores/main/censor_events.py create mode 100644 synapse/storage/data_stores/main/metrics.py create mode 100644 synapse/storage/data_stores/main/purge_events.py (limited to 'synapse/storage') diff --git a/changelog.d/7440.misc b/changelog.d/7440.misc new file mode 100644 index 0000000000..2505cdc66f --- /dev/null +++ b/changelog.d/7440.misc @@ -0,0 +1 @@ +Refactor event persistence database functions in preparation for allowing them to be run on non-master processes. diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index e1d03429ca..7c14ba91c2 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -15,6 +15,7 @@ import logging +from synapse.storage.data_stores.main.events import PersistEventsStore from synapse.storage.data_stores.state import StateGroupDataStore from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine @@ -39,6 +40,7 @@ class DataStores(object): self.databases = [] self.main = None self.state = None + self.persist_events = None for database_config in hs.config.database.databases: db_name = database_config.name @@ -64,6 +66,13 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) + # If we're on a process that can persist events (currently + # master), also instansiate a `PersistEventsStore` + if hs.config.worker.worker_app is None: + self.persist_events = PersistEventsStore( + hs, database, self.main + ) + if "state" in database_config.data_stores: logger.info("Starting 'state' data store") diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index cd2a1f0461..5df9dce79d 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -34,6 +34,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore from .cache import CacheInvalidationWorkerStore +from .censor_events import CensorEventsStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -42,16 +43,17 @@ from .e2e_room_keys import EndToEndRoomKeyStore from .end_to_end_keys import EndToEndKeyStore from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore -from .events import EventsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore from .media_repository import MediaRepositoryStore +from .metrics import ServerMetricsStore from .monthly_active_users import MonthlyActiveUsersStore from .openid import OpenIdStore from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore +from .purge_events import PurgeEventsStore from .push_rule import PushRuleStore from .pusher import PusherStore from .receipts import ReceiptsStore @@ -88,7 +90,7 @@ class DataStore( StateStore, SignatureStore, ApplicationServiceStore, - EventsStore, + PurgeEventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, @@ -113,8 +115,10 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, + CensorEventsStore, UIAuthStore, CacheInvalidationWorkerStore, + ServerMetricsStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py new file mode 100644 index 0000000000..49683b4d5a --- /dev/null +++ b/synapse/storage/data_stores/main/censor_events.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.internet import defer + +from synapse.events.utils import prune_event_dict +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore +from synapse.storage.data_stores.main.events import encode_json +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class CensorEventsStore(CacheInvalidationWorkerStore, EventsWorkerStore, SQLBaseStore): + def __init__(self, database: Database, db_conn, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + # This should only exist on master for now + assert ( + hs.config.worker.worker_app is None + ), "Can only instantiate CensorEventsStore on master" + + def _censor_redactions(): + return run_as_background_process( + "_censor_redactions", self._censor_redactions + ) + + if self.hs.config.redaction_retention_period is not None: + hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + + async def _censor_redactions(self): + """Censors all redactions older than the configured period that haven't + been censored yet. + + By censor we mean update the event_json table with the redacted event. + """ + + if self.hs.config.redaction_retention_period is None: + return + + if not ( + await self.db.updates.has_completed_background_update( + "redactions_have_censored_ts_idx" + ) + ): + # We don't want to run this until the appropriate index has been + # created. + return + + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + + # We fetch all redactions that: + # 1. point to an event we have, + # 2. has a received_ts from before the cut off, and + # 3. we haven't yet censored. + # + # This is limited to 100 events to ensure that we don't try and do too + # much at once. We'll get called again so this should eventually catch + # up. + sql = """ + SELECT redactions.event_id, redacts FROM redactions + LEFT JOIN events AS original_event ON ( + redacts = original_event.event_id + ) + WHERE NOT have_censored + AND redactions.received_ts <= ? + ORDER BY redactions.received_ts ASC + LIMIT ? + """ + + rows = await self.db.execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) + + updates = [] + + for redaction_id, event_id in rows: + redaction_event = await self.get_event(redaction_id, allow_none=True) + original_event = await self.get_event( + event_id, allow_rejected=True, allow_none=True + ) + + # The SQL above ensures that we have both the redaction and + # original event, so if the `get_event` calls return None it + # means that the redaction wasn't allowed. Either way we know that + # the result won't change so we mark the fact that we've checked. + if ( + redaction_event + and original_event + and original_event.internal_metadata.is_redacted() + ): + # Redaction was allowed + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) + else: + # Redaction wasn't allowed + pruned_json = None + + updates.append((redaction_id, event_id, pruned_json)) + + def _update_censor_txn(txn): + for redaction_id, event_id, pruned_json in updates: + if pruned_json: + self._censor_event_txn(txn, event_id, pruned_json) + + self.db.simple_update_one_txn( + txn, + table="redactions", + keyvalues={"event_id": redaction_id}, + updatevalues={"have_censored": True}, + ) + + await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.db.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.db.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index b99439cc37..24ce8c4330 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -640,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore): self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - if min_depth is not None and depth >= min_depth: - return - - self.db.simple_upsert_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - values={"min_depth": depth}, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self.db.simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id in ev.prev_event_ids() - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany( - query, - [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events - for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ], - ) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) - for ev in events - if not ev.internal_metadata.is_outlier() - ], - ) - def _delete_old_forward_extrem_cache(self): def _delete_old_forward_extrem_cache_txn(txn): # Delete entries older than a month, while making sure we don't delete diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 8eed590929..0321274de2 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -652,69 +652,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): - """Handles moving push actions from staging table to main - event_push_actions table for all events in `events_and_contexts`. - - Also ensures that all events in `all_events_and_contexts` are removed - from the push action staging area. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? - """ - - if events_and_contexts: - txn.executemany( - sql, - ( - ( - event.room_id, - event.internal_metadata.stream_ordering, - event.depth, - event.event_id, - ) - for event, _ in events_and_contexts - ), - ) - - for event, _ in events_and_contexts: - user_ids = self.db.simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcol="user_id", - ) - - for uid in user_ids: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid), - ) - - # Now we delete the staging area for *all* events that were being - # persisted. - txn.executemany( - "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ((event.event_id,) for event, _ in all_events_and_contexts), - ) - @defer.inlineCallbacks def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False @@ -763,17 +700,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) return result[0] or 0 - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,), - ) - txn.execute( - "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id), - ) - def _remove_old_push_actions_before_txn( self, txn, room_id, user_id, stream_ordering ): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index e71c23541d..092739962b 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,39 +17,44 @@ import itertools import logging -from collections import Counter as c_counter, OrderedDict, namedtuple +from collections import OrderedDict, namedtuple from functools import wraps -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple -from six import iteritems, text_type +from six import integer_types, iteritems, text_type from six.moves import range +import attr from canonicaljson import json from prometheus_client import Counter from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes -from synapse.api.errors import SynapseError +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.room_versions import RoomVersions +from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.events.utils import prune_event_dict from synapse.logging.utils import log_function -from synapse.metrics import BucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.data_stores.main.event_federation import EventFederationStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.data_stores.main.search import SearchEntry from synapse.storage.database import Database, LoggingTransaction -from synapse.storage.persist_events import DeltaState -from synapse.types import RoomStreamToken, StateMap, get_domain_from_id -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) persist_event_counter = Counter("synapse_storage_events_persisted_events", "") @@ -94,58 +99,49 @@ def _retry_on_integrity_error(func): return f -# inherits from EventFederationStore so that we can call _update_backward_extremities -# and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore( - StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, -): - def __init__(self, database: Database, db_conn, hs): - super(EventsStore, self).__init__(database, db_conn, hs) +@attr.s(slots=True) +class DeltaState: + """Deltas to use to update the `current_state_events` table. - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = c_counter() + Attributes: + to_delete: List of type/state_keys to delete from current state + to_insert: Map of state to upsert into current state + no_longer_in_room: The server is not longer in the room, so the room + should e.g. be removed from `current_state_events` table. + """ - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) + to_delete = attr.ib(type=List[Tuple[str, str]]) + to_insert = attr.ib(type=StateMap[str]) + no_longer_in_room = attr.ib(type=bool, default=False) - # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) +class PersistEventsStore: + """Contains all the functions for writing events to the database. - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) + Should only be instansiated on one process (when using a worker mode setup). + + Note: This is not part of the `DataStore` mixin. + """ - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): + self.hs = hs + self.db = db + self.store = main_data_store + self.database_engine = db.engine + self._clock = hs.get_clock() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def _read_forward_extremities(self): - def fetch(txn): - txn.execute( - """ - select count(*) c from event_forward_extremities - group by room_id - """ - ) - return txn.fetchall() + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator - res = yield self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter([x[0] for x in res]) + # This should only exist on master for now + assert ( + hs.config.worker.worker_app is None + ), "Can only instantiate PersistEventsStore on master" @_retry_on_integrity_error @defer.inlineCallbacks @@ -237,10 +233,10 @@ class EventsStore( event_counter.labels(event.type, origin_type, origin_entity).inc() for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) + self.store.get_current_state_ids.prefill((room_id,), new_state) for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( + self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -586,7 +582,7 @@ class EventsStore( ) txn.call_after( - self._curr_state_delta_stream_cache.entity_has_changed, + self.store._curr_state_delta_stream_cache.entity_has_changed, room_id, stream_id, ) @@ -606,10 +602,13 @@ class EventsStore( for member in members_changed: txn.call_after( - self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) + self.store.get_rooms_for_user_with_stream_ordering.invalidate, + (member,), ) - self._invalidate_state_caches_and_stream(txn, room_id, members_changed) + self.store._invalidate_state_caches_and_stream( + txn, room_id, members_changed + ) def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): """Update the room version in the database based off current state @@ -647,7 +646,9 @@ class EventsStore( self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + txn.call_after( + self.store.get_latest_event_ids_in_room.invalidate, (room_id,) + ) self.db.simple_insert_many_txn( txn, @@ -713,10 +714,10 @@ class EventsStore( depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids - txn.call_after(self._invalidate_get_event_cache, event.event_id) + txn.call_after(self.store._invalidate_get_event_cache, event.event_id) if not backfilled: txn.call_after( - self._events_stream_cache.entity_has_changed, + self.store._events_stream_cache.entity_has_changed, event.room_id, event.internal_metadata.stream_ordering, ) @@ -1088,13 +1089,15 @@ class EventsStore( def prefill(): for cache_entry in to_prefill: - self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) + self.store._get_event_cache.prefill( + (cache_entry[0].event_id,), cache_entry + ) txn.call_after(prefill) def _store_redaction(self, txn, event): # invalidate the cache for the redacted event - txn.call_after(self._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store._invalidate_get_event_cache, event.redacts) self.db.simple_insert_txn( txn, @@ -1106,783 +1109,484 @@ class EventsStore( }, ) - async def _censor_redactions(self): - """Censors all redactions older than the configured period that haven't - been censored yet. + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. - By censor we mean update the event_json table with the redacted event. + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. """ + return self.db.simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], + ) - if self.hs.config.redaction_retention_period is None: - return - - if not ( - await self.db.updates.has_completed_background_update( - "redactions_have_censored_ts_idx" - ) - ): - # We don't want to run this until the appropriate index has been - # created. - return - - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. - # We fetch all redactions that: - # 1. point to an event we have, - # 2. has a received_ts from before the cut off, and - # 3. we haven't yet censored. - # - # This is limited to 100 events to ensure that we don't try and do too - # much at once. We'll get called again so this should eventually catch - # up. - sql = """ - SELECT redactions.event_id, redacts FROM redactions - LEFT JOIN events AS original_event ON ( - redacts = original_event.event_id - ) - WHERE NOT have_censored - AND redactions.received_ts <= ? - ORDER BY redactions.received_ts ASC - LIMIT ? + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. """ - - rows = await self.db.execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 + return self.db.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, ) - updates = [] + def _store_event_reference_hashes_txn(self, txn, events): + """Store a hash for a PDU + Args: + txn (cursor): + events (list): list of Events. + """ - for redaction_id, event_id in rows: - redaction_event = await self.get_event(redaction_id, allow_none=True) - original_event = await self.get_event( - event_id, allow_rejected=True, allow_none=True + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append( + { + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": memoryview(ref_hash_bytes), + } ) - # The SQL above ensures that we have both the redaction and - # original event, so if the `get_event` calls return None it - # means that the redaction wasn't allowed. Either way we know that - # the result won't change so we mark the fact that we've checked. - if ( - redaction_event - and original_event - and original_event.internal_metadata.is_redacted() - ): - # Redaction was allowed - pruned_json = encode_json( - prune_event_dict( - original_event.room_version, original_event.get_dict() - ) - ) - else: - # Redaction wasn't allowed - pruned_json = None - - updates.append((redaction_id, event_id, pruned_json)) - - def _update_censor_txn(txn): - for redaction_id, event_id, pruned_json in updates: - if pruned_json: - self._censor_event_txn(txn, event_id, pruned_json) - - self.db.simple_update_one_txn( - txn, - table="redactions", - keyvalues={"event_id": redaction_id}, - updatevalues={"have_censored": True}, - ) - - await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) - def _censor_event_txn(self, txn, event_id, pruned_json): - """Censor an event by replacing its JSON in the event_json table with the - provided pruned JSON. - - Args: - txn (LoggingTransaction): The database transaction. - event_id (str): The ID of the event to censor. - pruned_json (str): The pruned JSON + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. """ - self.db.simple_update_one_txn( + self.db.simple_insert_many_txn( txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], ) - @defer.inlineCallbacks - def count_daily_messages(self): - """ - Returns an estimate of the number of messages sent in the last day. - - If it has been significantly less or more than one day since the last - call to this function, it will return None. - """ - - def _count_messages(txn): - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_sent_messages(self): - def _count_messages(txn): - # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. - like_clause = "%:" + self.hs.hostname - - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND sender LIKE ? - AND stream_ordering > ? - """ - - txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_active_rooms(self): - def _count(txn): - sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_daily_active_rooms", _count) - return ret - - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" + for event in events: + txn.call_after( + self.store._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" + txn.call_after( + self.store.get_invited_rooms_for_local_user.invalidate, + (event.state_key,), ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. If the event is an + # outlier it is only current if its an "out of band membership", + # like a remote invite or a rejection of a remote invite. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_out_of_band_membership() ) + is_mine = self.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self.db.simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + }, + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) - return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) - - def purge_history(self, room_id, token, delete_local_events): - """Deletes room history before a certain point - - Args: - room_id (str): + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) - token (str): A topological token to delete events before + # We also update the `local_current_membership` table with + # latest invite info. This will usually get updated by the + # `current_state_events` handling, unless its an outlier. + if event.internal_metadata.is_outlier(): + # This should only happen for out of band memberships, so + # we add a paranoia check. + assert event.internal_metadata.is_out_of_band_membership() + + self.db.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) - delete_local_events (bool): - if True, we will delete local events as well as remote ones - (instead of just marking them as outliers and deleting their - state groups). + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events - Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + Args: + txn + event (EventBase) """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return - return self.db.runInteraction( - "purge_history", - self._purge_history_txn, - room_id, - token, - delete_local_events, - ) + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - - # Tables that should be pruned: - # event_auth - # event_backward_extremities - # event_edges - # event_forward_extremities - # event_json - # event_push_actions - # event_reference_hashes - # event_search - # event_to_state_groups - # events - # rejections - # room_depth - # state_groups - # state_groups_state - - # we will build a temporary table listing the events so that we don't - # have to keep shovelling the list back and forth across the - # connection. Annoyingly the python sqlite driver commits the - # transaction on CREATE, so let's do this first. - # - # furthermore, we might already have the table from a previous (failed) - # purge attempt, so let's drop the table first. + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return - txn.execute("DROP TABLE IF EXISTS events_to_purge") + aggregation_key = relation.get("key") - txn.execute( - "CREATE TEMPORARY TABLE events_to_purge (" - " event_id TEXT NOT NULL," - " should_delete BOOLEAN NOT NULL" - ")" + self.db.simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, ) - # First ensure that we're not about to delete all the forward extremeties - txn.execute( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?", - (room_id,), + txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - rows = txn.fetchall() - max_depth = max(row[1] for row in rows) - - if max_depth < token.topological: - # We need to ensure we don't delete all the events from the database - # otherwise we wouldn't be able to send any events (due to not - # having any backwards extremeties) - raise SynapseError( - 400, "topological_ordering is greater than forward extremeties" - ) - - logger.info("[purge] looking for events to delete") - - should_delete_expr = "state_key IS NULL" - should_delete_params = () - if not delete_local_events: - should_delete_expr += " AND event_id NOT LIKE ?" - - # We include the parameter twice since we use the expression twice - should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) - should_delete_params += (room_id, token.topological) + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - # Note that we insert events that are outliers and aren't going to be - # deleted, as nothing will happen to them. - txn.execute( - "INSERT INTO events_to_purge" - " SELECT event_id, %s" - " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" - % (should_delete_expr, should_delete_expr), - should_delete_params, - ) + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. - # We create the indices *after* insertion as that's a lot faster. + Args: + txn + redacted_event_id (str): The event that was redacted. + """ - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)" + self.db.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") - - txn.execute("SELECT event_id, should_delete FROM events_to_purge") - event_rows = txn.fetchall() - logger.info( - "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), - sum(1 for e in event_rows if e[1]), - ) + def _store_room_topic_txn(self, txn, event): + if hasattr(event, "content") and "topic" in event.content: + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) - logger.info("[purge] Finding new backward extremities") + def _store_room_name_txn(self, txn, event): + if hasattr(event, "content") and "name" in event.content: + self.store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) - # We calculate the new entries for the backward extremeties by finding - # events to be purged that are pointed to by events we're not going to - # purge. - txn.execute( - "SELECT DISTINCT e.event_id FROM events_to_purge AS e" - " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" - " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL" - ) - new_backwards_extrems = txn.fetchall() + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self.store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) - logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + def _store_retention_policy_for_room_txn(self, txn, event): + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content + ): + if ( + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) + ): + # Ignore the event if one of the value isn't an integer. + return - txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) - ) + self.db.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) - # Update backward extremeties - txn.executemany( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_retention_policy_for_room, (event.room_id,) + ) - logger.info("[purge] finding state groups referenced by deleted events") + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table - # Get all state groups that are referenced by events that are to be - # deleted. - txn.execute( - """ - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): """ + self.store.store_search_entries_txn( + txn, + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), ) - referenced_state_groups = {sg for sg, in txn} - logger.info( - "[purge] found %i referenced state groups", len(referenced_state_groups) - ) + def _set_push_actions_for_event_and_users_txn( + self, txn, events_and_contexts, all_events_and_contexts + ): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. - logger.info("[purge] removing events from event_to_state_groups") - txn.execute( - "DELETE FROM event_to_state_groups " - "WHERE event_id IN (SELECT event_id from events_to_purge)" - ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. - # Delete all remote non-state events - for table in ( - "events", - "event_json", - "event_auth", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "rejections", - ): - logger.info("[purge] removing events from %s", table) + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ - txn.execute( - "DELETE FROM %s WHERE event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,) + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ - # event_push_actions lacks an index on event_id, and has one on - # (room_id, event_id) instead. - for table in ("event_push_actions",): - logger.info("[purge] removing events from %s", table) + if events_and_contexts: + txn.executemany( + sql, + ( + ( + event.room_id, + event.internal_metadata.stream_ordering, + event.depth, + event.event_id, + ) + for event, _ in events_and_contexts + ), + ) - txn.execute( - "DELETE FROM %s WHERE room_id = ? AND event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), - (room_id,), + for event, _ in events_and_contexts: + user_ids = self.db.simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcol="user_id", ) - # Mark all state and own events as outliers - logger.info("[purge] marking remaining events as outliers") - txn.execute( - "UPDATE events SET outlier = ?" - " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), + for uid in user_ids: + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid), + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ((event.event_id,) for event, _ in all_events_and_contexts), ) - # synapse tries to take out an exclusive lock on room_depth whenever it - # persists events (because upsert), and once we run this update, we - # will block that for the rest of our transaction. - # - # So, let's stick it at the end so that we don't block event - # persistence. - # - # We do this by calculating the minimum depth of the backwards - # extremities. However, the events in event_backward_extremities - # are ones we don't have yet so we need to look at the events that - # point to it via event_edges table. - txn.execute( - """ - SELECT COALESCE(MIN(depth), 0) - FROM event_backward_extremities AS eb - INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id - INNER JOIN events AS e ON e.event_id = eg.event_id - WHERE eb.room_id = ? - """, + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + # Sad that we have to blow away the cache for the whole room here + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, (room_id,), ) - (min_depth,) = txn.fetchone() - - logger.info("[purge] updating room_depth to %d", min_depth) - txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id), + "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", + (room_id, event_id), ) - # finally, drop the temp table. this will commit the txn in sqlite, - # so make sure to keep this actually last. - txn.execute("DROP TABLE events_to_purge") - - logger.info("[purge] done") - - return referenced_state_groups - - def purge_room(self, room_id): - """Deletes all record of a room + def _store_rejections_txn(self, txn, event_id, reason): + self.db.simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + }, + ) - Args: - room_id (str) + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue - Returns: - Deferred[List[int]]: The list of state groups to delete. - """ + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.state_group_before_event + continue - return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + state_groups[event.event_id] = context.state_group - def _purge_room_txn(self, txn, room_id): - # First we fetch all the state groups that should be deleted, before - # we delete that information. - txn.execute( - """ - SELECT DISTINCT state_group FROM events - INNER JOIN event_to_state_groups USING(event_id) - WHERE events.room_id = ? - """, - (room_id,), + self.db.simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in iteritems(state_groups) + ], ) - state_groups = [row[0] for row in txn] - - # Now we delete tables which lack an index on room_id but have one on event_id - for table in ( - "event_auth", - "event_edges", - "event_push_actions_staging", - "event_reference_hashes", - "event_relations", - "event_to_state_groups", - "redactions", - "rejections", - "state_events", - ): - logger.info("[purge] removing %s from %s", room_id, table) - - txn.execute( - """ - DELETE FROM %s WHERE event_id IN ( - SELECT event_id FROM events WHERE room_id=? - ) - """ - % (table,), - (room_id,), + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self.store._get_state_group_for_event.prefill, + (event_id,), + state_group_id, ) - # and finally, the tables with an index on room_id (or no useful index) - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - # no useful index, but let's clear them anyway - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - "local_current_membership", - ): - logger.info("[purge] removing %s from %s", room_id, table) - txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) - - # Other tables we do NOT need to clear out: - # - # - blocked_rooms - # This is important, to make sure that we don't accidentally rejoin a blocked - # room after it was purged - # - # - user_directory - # This has a room_id column, but it is unused - # - - # Other tables that we might want to consider clearing out include: - # - # - event_reports - # Given that these are intended for abuse management my initial - # inclination is to leave them in place. - # - # - current_state_delta_stream - # - ex_outlier_stream - # - room_tags_revisions - # The problem with these is that they are largeish and there is no room_id - # index on them. In any case we should be clearing out 'stream' tables - # periodically anyway (#5888) - - # TODO: we could probably usefully do a bunch of cache invalidation here + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self.store._get_min_depth_interaction(txn, room_id) - logger.info("[purge] done") - - return state_groups - - async def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) - return (to_1, so_1) > (to_2, so_2) + if min_depth is not None and depth >= min_depth: + return - @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): - res = yield self.db.simple_select_one( - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True, + self.db.simple_upsert_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + values={"min_depth": depth}, ) - if not res: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - return (int(res["topological_ordering"]), int(res["stream_ordering"])) - - def insert_labels_for_event_txn( - self, txn, event_id, labels, room_id, topological_ordering - ): - """Store the mapping between an event's ID and its labels, with one row per - (event_id, label) tuple. - - Args: - txn (LoggingTransaction): The transaction to execute. - event_id (str): The event's ID. - labels (list[str]): A list of text labels. - room_id (str): The ID of the room the event was sent to. - topological_ordering (int): The position of the event in the room's topology. + def _handle_mult_prev_events(self, txn, events): """ - return self.db.simple_insert_many_txn( - txn=txn, - table="event_labels", + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self.db.simple_insert_many_txn( + txn, + table="event_edges", values=[ { - "event_id": event_id, - "label": label, - "room_id": room_id, - "topological_ordering": topological_ordering, + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, } - for label in labels + for ev in events + for e_id in ev.prev_event_ids() ], ) - def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): - """Save the expiry timestamp associated with a given event ID. - - Args: - txn (LoggingTransaction): The database transaction to use. - event_id (str): The event ID the expiry timestamp is associated with. - expiry_ts (int): The timestamp at which to expire (delete) the event. - """ - return self.db.simple_insert_txn( - txn=txn, - table="event_expiry", - values={"event_id": event_id, "expiry_ts": expiry_ts}, - ) - - @defer.inlineCallbacks - def expire_event(self, event_id): - """Retrieve and expire an event that has expired, and delete its associated - expiry timestamp. If the event can't be retrieved, delete its associated - timestamp so we don't try to expire it again in the future. - - Args: - event_id (str): The ID of the event to delete. - """ - # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) - - def delete_expired_event_txn(txn): - # Delete the expiry timestamp associated with this event from the database. - self._delete_event_expiry_txn(txn, event_id) - - if not event: - # If we can't find the event, log a warning and delete the expiry date - # from the database so that we don't try to expire it again in the - # future. - logger.warning( - "Can't expire event %s because we don't have it.", event_id - ) - return - - # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( - prune_event_dict(event.room_version, event.get_dict()) - ) - - # Update the event_json table to replace the event's JSON with the pruned - # JSON. - self._censor_event_txn(txn, event.event_id, pruned_json) - - # We need to invalidate the event cache entry for this event because we - # changed its content in the database. We can't call - # self._invalidate_cache_and_stream because self.get_event_cache isn't of the - # right type. - txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) - # Send that invalidation to replication so that other workers also invalidate - # the event cache. - self._send_invalidation_to_replication( - txn, "_get_event_cache", (event.event_id,) - ) + self._update_backward_extremeties(txn, events) - yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. - def _delete_event_expiry_txn(self, txn, event_id): - """Delete the expiry timestamp associated with an event ID without deleting the - actual event. + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. - Args: - txn (LoggingTransaction): The transaction to use to perform the deletion. - event_id (str): The event ID to delete the associated expiry timestamp of. + Forward extremities are handled when we first start persisting the events. """ - return self.db.simple_delete_txn( - txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" ) - def get_next_event_to_expire(self): - """Retrieve the entry with the lowest expiry timestamp in the event_expiry - table, or None if there's no more event to expire. - - Returns: Deferred[Optional[Tuple[str, int]]] - A tuple containing the event ID as its first element and an expiry timestamp - as its second one, if there's at least one row in the event_expiry table. - None otherwise. - """ - - def get_next_event_to_expire_txn(txn): - txn.execute( - """ - SELECT event_id, expiry_ts FROM event_expiry - ORDER BY expiry_ts ASC LIMIT 1 - """ - ) - - return txn.fetchone() - - return self.db.runInteraction( - desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + txn.executemany( + query, + [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + ], ) - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) + for ev in events + if not ev.internal_metadata.is_outlier() + ], + ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index b8c1bbdf99..970c31bd05 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -27,7 +27,7 @@ from constantly import NamedConstant, Names from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, @@ -40,7 +40,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1157,4 +1157,152 @@ class EventsWorkerStore(SQLBaseStore): rows = await self.db.runInteraction( "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True + + @cached(num_args=5, max_entries=10) + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + if have_forward_events: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() + else: + new_forward_events = [] + forward_ex_outliers = [] + + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + if have_backfill_events: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() + else: + new_backfill_events = [] + backward_ex_outliers = [] + + return AllNewEventsResult( + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, + ) + + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) + + async def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = await self._get_event_ordering(event_id1) + to_2, so_2 = await self._get_event_ordering(event_id2) + return (to_1, so_1) > (to_2, so_2) + + @cachedInlineCallbacks(max_entries=5000) + def _get_event_ordering(self, event_id): + res = yield self.db.simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + return (int(res["topological_ordering"]), int(res["stream_ordering"])) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + + +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py new file mode 100644 index 0000000000..dad5bbc602 --- /dev/null +++ b/synapse/storage/data_stores/main/metrics.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from collections import Counter + +from twisted.internet import defer + +from synapse.metrics import BucketCollector +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.database import Database + + +class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): + """Functions to pull various metrics from the DB, for e.g. phone home + stats and prometheus metrics. + """ + + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + # Collect metrics on the number of forward extremities that exist. + # Counter of number of extremities to count + self._current_forward_extremities_amount = ( + Counter() + ) # type: typing.Counter[int] + + BucketCollector( + "synapse_forward_extremities", + lambda: self._current_forward_extremities_amount, + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], + ) + + # Read the extrems every 60 minutes + def read_forward_extremities(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "read_forward_extremities", self._read_forward_extremities + ) + + hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + + async def _read_forward_extremities(self): + def fetch(txn): + txn.execute( + """ + select count(*) c from event_forward_extremities + group by room_id + """ + ) + return txn.fetchall() + + res = await self.db.runInteraction("read_forward_extremities", fetch) + self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + @defer.inlineCallbacks + def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) + return ret diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py new file mode 100644 index 0000000000..a93e1ef198 --- /dev/null +++ b/synapse/storage/data_stores/main/purge_events.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Tuple + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.types import RoomStreamToken + +logger = logging.getLogger(__name__) + + +class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + + Returns: + Deferred[set[int]]: The set of state groups that are referenced by + deleted events. + """ + + return self.db.runInteraction( + "purge_history", + self._purge_history_txn, + room_id, + token, + delete_local_events, + ) + + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): + token = RoomStreamToken.parse(token_str) + + # Tables that should be pruned: + # event_auth + # event_backward_extremities + # event_edges + # event_forward_extremities + # event_json + # event_push_actions + # event_reference_hashes + # event_search + # event_to_state_groups + # events + # rejections + # room_depth + # state_groups + # state_groups_state + + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # First ensure that we're not about to delete all the forward extremeties + txn.execute( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?", + (room_id,), + ) + rows = txn.fetchall() + max_depth = max(row[1] for row in rows) + + if max_depth < token.topological: + # We need to ensure we don't delete all the events from the database + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremeties) + raise SynapseError( + 400, "topological_ordering is greater than forward extremeties" + ) + + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () # type: Tuple[Any, ...] + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + + # We include the parameter twice since we use the expression twice + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) + + should_delete_params += (room_id, token.topological) + + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" + % (should_delete_expr, should_delete_expr), + should_delete_params, + ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)" + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") + + txn.execute("SELECT event_id, should_delete FROM events_to_purge") + event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), + sum(1 for e in event_rows if e[1]), + ) + + logger.info("[purge] Finding new backward extremities") + + # We calculate the new entries for the backward extremeties by finding + # events to be purged that are pointed to by events we're not going to + # purge. + txn.execute( + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" + " WHERE ep2.event_id IS NULL" + ) + new_backwards_extrems = txn.fetchall() + + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + + txn.execute( + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) + ) + + # Update backward extremeties + txn.executemany( + "INSERT INTO event_backward_extremities (room_id, event_id)" + " VALUES (?, ?)", + [(room_id, event_id) for event_id, in new_backwards_extrems], + ) + + logger.info("[purge] finding state groups referenced by deleted events") + + # Get all state groups that are referenced by events that are to be + # deleted. + txn.execute( + """ + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) + """ + ) + + referenced_state_groups = {sg for sg, in txn} + logger.info( + "[purge] found %i referenced state groups", len(referenced_state_groups) + ) + + logger.info("[purge] removing events from event_to_state_groups") + txn.execute( + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" + ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + + # Delete all remote non-state events + for table in ( + "events", + "event_json", + "event_auth", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "rejections", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,) + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ("event_push_actions",): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id,), + ) + + # Mark all state and own events as outliers + logger.info("[purge] marking remaining events as outliers") + txn.execute( + "UPDATE events SET outlier = ?" + " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + # + # We do this by calculating the minimum depth of the backwards + # extremities. However, the events in event_backward_extremities + # are ones we don't have yet so we need to look at the events that + # point to it via event_edges table. + txn.execute( + """ + SELECT COALESCE(MIN(depth), 0) + FROM event_backward_extremities AS eb + INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id + INNER JOIN events AS e ON e.event_id = eg.event_id + WHERE eb.room_id = ? + """, + (room_id,), + ) + (min_depth,) = txn.fetchone() + + logger.info("[purge] updating room_depth to %d", min_depth) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (min_depth, room_id), + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute("DROP TABLE events_to_purge") + + logger.info("[purge] done") + + return referenced_state_groups + + def purge_room(self, room_id): + """Deletes all record of a room + + Args: + room_id (str) + + Returns: + Deferred[List[int]]: The list of state groups to delete. + """ + + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + + def _purge_room_txn(self, txn, room_id): + # First we fetch all the state groups that should be deleted, before + # we delete that information. + txn.execute( + """ + SELECT DISTINCT state_group FROM events + INNER JOIN event_to_state_groups USING(event_id) + WHERE events.room_id = ? + """, + (room_id,), + ) + + state_groups = [row[0] for row in txn] + + # Now we delete tables which lack an index on room_id but have one on event_id + for table in ( + "event_auth", + "event_edges", + "event_push_actions_staging", + "event_reference_hashes", + "event_relations", + "event_to_state_groups", + "redactions", + "rejections", + "state_events", + ): + logger.info("[purge] removing %s from %s", room_id, table) + + txn.execute( + """ + DELETE FROM %s WHERE event_id IN ( + SELECT event_id FROM events WHERE room_id=? + ) + """ + % (table,), + (room_id,), + ) + + # and finally, the tables with an index on room_id (or no useful index) + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + # no useful index, but let's clear them anyway + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + "local_current_membership", + ): + logger.info("[purge] removing %s from %s", room_id, table) + txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) + + # Other tables we do NOT need to clear out: + # + # - blocked_rooms + # This is important, to make sure that we don't accidentally rejoin a blocked + # room after it was purged + # + # - user_directory + # This has a room_id column, but it is unused + # + + # Other tables that we might want to consider clearing out include: + # + # - event_reports + # Given that these are intended for abuse management my initial + # inclination is to leave them in place. + # + # - current_state_delta_stream + # - ex_outlier_stream + # - room_tags_revisions + # The problem with these is that they are largeish and there is no room_id + # index on them. In any case we should be clearing out 'stream' tables + # periodically anyway (#5888) + + # TODO: we could probably usefully do a bunch of cache invalidation here + + logger.info("[purge] done") + + return state_groups diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 1c07c7a425..27e5a2084a 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -21,17 +21,6 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def _store_rejections_txn(self, txn, event_id, reason): - self.db.simple_insert_txn( - txn, - table="rejections", - values={ - "event_id": event_id, - "reason": reason, - "last_check": self._clock.time_msec(), - }, - ) - def get_rejection_reason(self, event_id): return self.db.simple_select_one_onecol( table="rejections", diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 046c2b4845..7d477f8d01 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -324,62 +324,4 @@ class RelationsWorkerStore(SQLBaseStore): class RelationsStore(RelationsWorkerStore): - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events - - Args: - txn - event (EventBase) - """ - relation = event.content.get("m.relates_to") - if not relation: - # No relations - return - - rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - # Unknown relation type - return - - parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation - return - - aggregation_key = relation.get("key") - - self.db.simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, - }, - ) - - txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) - txn.call_after( - self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) - ) - - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) - - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. - - Args: - txn - redacted_event_id (str): The event that was redacted. - """ - - self.db.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + pass diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index cafa664c16..46f643c6b9 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -21,8 +21,6 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from six import integer_types - from canonicaljson import json from twisted.internet import defer @@ -1302,53 +1300,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return self.db.runInteraction("get_rooms", f) - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: - self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"] - ) - - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: - self.store_event_search_txn( - txn, event, "content.name", event.content["name"] - ) - - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: - self.store_event_search_txn( - txn, event, "content.body", event.content["body"] - ) - - def _store_retention_policy_for_room_txn(self, txn, event): - if hasattr(event, "content") and ( - "min_lifetime" in event.content or "max_lifetime" in event.content - ): - if ( - "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), integer_types) - ) or ( - "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), integer_types) - ): - # Ignore the event if one of the value isn't an integer. - return - - self.db.simple_insert_txn( - txn=txn, - table="room_retention", - values={ - "room_id": event.room_id, - "event_id": event.event_id, - "min_lifetime": event.content.get("min_lifetime"), - "max_lifetime": event.content.get("max_lifetime"), - }, - ) - - self._invalidate_cache_and_stream( - txn, self.get_retention_policy_for_room, (event.room_id,) - ) - def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index ae9a76713c..04ffdcf934 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1050,96 +1050,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self.db.simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.get_invited_rooms_for_local_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self.db.simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - # We also update the `local_current_membership` table with - # latest invite info. This will usually get updated by the - # `current_state_events` handling, unless its an outlier. - if event.internal_metadata.is_outlier(): - # This should only happen for out of band memberships, so - # we add a paranoia check. - assert event.internal_metadata.is_out_of_band_membership() - - self.db.simple_upsert_txn( - txn, - table="local_current_membership", - keyvalues={ - "room_id": event.room_id, - "user_id": event.state_key, - }, - values={ - "event_id": event.event_id, - "membership": event.membership, - }, - ) - @defer.inlineCallbacks def locally_reject_invite(self, user_id, room_id): sql = ( diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 47ebb8a214..ee75b92344 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -347,29 +347,6 @@ class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(SearchStore, self).__init__(database, db_conn, hs) - def store_event_search_txn(self, txn, event, key, value): - """Add event to the search table - - Args: - txn (cursor): - event (EventBase): - key (str): - value (str): - """ - self.store_search_entries_txn( - txn, - ( - SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ), - ), - ) - @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 563216b63c..36244d9f5d 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -13,23 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six - from unpaddedbase64 import encode_base64 from twisted.internet import defer -from synapse.crypto.event_signing import compute_event_reference_hash from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class SignatureWorkerStore(SQLBaseStore): @cached() @@ -79,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore): class SignatureStore(SignatureWorkerStore): """Persistence for event signatures and hashes""" - - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. - """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append( - { - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": db_binary_type(ref_hash_bytes), - } - ) - - self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3a3b9a8e72..21052fcc7a 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -16,17 +16,12 @@ import collections.abc import logging from collections import namedtuple -from typing import Iterable, Tuple - -from six import iteritems from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion -from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore @@ -473,33 +468,3 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(StateStore, self).__init__(database, db_conn, hs) - - def _store_event_state_mappings_txn( - self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] - ): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.state_group_before_event - continue - - state_groups[event.event_id] = context.state_group - - self.db.simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) - ], - ) - - for event_id, state_group_id in iteritems(state_groups): - txn.call_after( - self._get_state_group_for_event.prefill, (event_id,), state_group_id - ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 0f9ac1cf09..41881ea20b 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -23,7 +23,6 @@ from typing import Iterable, List, Optional, Set, Tuple from six import iteritems from six.moves import range -import attr from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -35,6 +34,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -73,22 +73,6 @@ stale_forward_extremities_counter = Histogram( ) -@attr.s(slots=True) -class DeltaState: - """Deltas to use to update the `current_state_events` table. - - Attributes: - to_delete: List of type/state_keys to delete from current state - to_insert: Map of state to upsert into current state - no_longer_in_room: The server is not longer in the room, so the room - should e.g. be removed from `current_state_events` table. - """ - - to_delete = attr.ib(type=List[Tuple[str, str]]) - to_insert = attr.ib(type=StateMap[str]) - no_longer_in_room = attr.ib(type=bool, default=False) - - class _EventPeristenceQueue(object): """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. @@ -205,6 +189,7 @@ class EventsPersistenceStorage(object): # store for now. self.main_store = stores.main self.state_store = stores.state + self.persist_events_store = stores.persist_events self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -445,7 +430,7 @@ class EventsPersistenceStorage(object): if current_state is not None: current_state_for_room[room_id] = current_state - await self.main_store._persist_events_and_state_updates( + await self.persist_events_store._persist_events_and_state_updates( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, @@ -491,13 +476,15 @@ class EventsPersistenceStorage(object): ) # Remove any events which are prev_events of any existing events. - existing_prevs = await self.main_store._get_events_which_are_prevs(result) + existing_prevs = await self.persist_events_store._get_events_which_are_prevs( + result + ) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev # events. If they do we need to remove them and their prev events, # otherwise we end up with dangling extremities. - existing_prevs = await self.main_store._get_prevs_before_rejected( + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( e_id for event in new_events for e_id in event.prev_event_ids() ) result.difference_update(existing_prevs) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index d4bcf1821e..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.persist_events_store = hs.get_datastores().persist_events @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) yield self.store.db.runInteraction( "", - self.store._set_push_actions_for_event_and_users_txn, + self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], [(event, None)], ) -- cgit 1.5.1 From 00ba9c48bf1da6a3173fd77ed72d7d8295f4d051 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 10:37:11 +0100 Subject: Spelling --- synapse/storage/data_stores/__init__.py | 2 +- synapse/storage/data_stores/main/events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 7c14ba91c2..791961b296 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -67,7 +67,7 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) # If we're on a process that can persist events (currently - # master), also instansiate a `PersistEventsStore` + # master), also instantiate a `PersistEventsStore` if hs.config.worker.worker_app is None: self.persist_events = PersistEventsStore( hs, database, self.main diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 092739962b..a97f8b3934 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -118,7 +118,7 @@ class DeltaState: class PersistEventsStore: """Contains all the functions for writing events to the database. - Should only be instansiated on one process (when using a worker mode setup). + Should only be instantiated on one process (when using a worker mode setup). Note: This is not part of the `DataStore` mixin. """ -- cgit 1.5.1 From 1124111a12c3ab35f8b68d9031695aec8b2c7c50 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 17:15:40 +0100 Subject: Allow censoring of events to happen on workers. (#7492) This is safe as we can now write to cache invalidation stream on workers, and is required for when we move event persistence off master. --- changelog.d/7492.misc | 1 + synapse/app/generic_worker.py | 2 ++ synapse/handlers/message.py | 2 -- synapse/storage/data_stores/main/censor_events.py | 7 +------ 4 files changed, 4 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7492.misc (limited to 'synapse/storage') diff --git a/changelog.d/7492.misc b/changelog.d/7492.misc new file mode 100644 index 0000000000..5ad31819d6 --- /dev/null +++ b/changelog.d/7492.misc @@ -0,0 +1 @@ +Allow censoring of events to happen on workers. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 667ad20428..bccb1140b2 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer +from synapse.storage.data_stores.main.censor_events import CensorEventsStore from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, @@ -442,6 +443,7 @@ class GenericWorkerSlavedStore( SlavedGroupServerStore, SlavedAccountDataStore, SlavedPusherStore, + CensorEventsStore, SlavedEventStore, SlavedKeyStore, RoomStore, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a622a600b4..0242521cc6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -72,7 +72,6 @@ class MessageHandler(object): self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages - self._is_worker_app = bool(hs.config.worker_app) # The scheduled call to self._expire_event. None if no call is currently # scheduled. @@ -260,7 +259,6 @@ class MessageHandler(object): Args: event (EventBase): The event to schedule the expiry of. """ - assert not self._is_worker_app expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) if not isinstance(expiry_ts, int) or event.is_state(): diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py index 49683b4d5a..2d48261724 100644 --- a/synapse/storage/data_stores/main/censor_events.py +++ b/synapse/storage/data_stores/main/censor_events.py @@ -33,15 +33,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class CensorEventsStore(CacheInvalidationWorkerStore, EventsWorkerStore, SQLBaseStore): +class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - # This should only exist on master for now - assert ( - hs.config.worker.worker_app is None - ), "Can only instantiate CensorEventsStore on master" - def _censor_redactions(): return run_as_background_process( "_censor_redactions", self._censor_redactions -- cgit 1.5.1 From 86614e251f25ef3f1741b4cd9c9d60ee9e754862 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 15 May 2020 16:17:12 +0100 Subject: Fix a small typo in the arguments of simple_update in update_remote_profile_cache (#7511) --- changelog.d/7511.bugfix | 1 + synapse/storage/data_stores/main/profile.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7511.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7511.bugfix b/changelog.d/7511.bugfix new file mode 100644 index 0000000000..cf8bc69c6f --- /dev/null +++ b/changelog.d/7511.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug that broke the update remote profile background process. diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index 2b52cf9c1a..bfc9369f0b 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore): return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, - values={ + updatevalues={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), -- cgit 1.5.1 From 1f36ff69e8c7b2853e906598898878742d39c42b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 May 2020 16:43:59 +0100 Subject: Move event stream handling out of slave store. (#7491) This allows us to have the logic on both master and workers, which is necessary to move event persistence off master. We also combine the instantiation of ID generators from DataStore and slave stores to the base worker stores. This allows us to select which process writes events independently of the master/worker splits. --- changelog.d/7491.misc | 1 + synapse/replication/slave/storage/events.py | 89 ------------------- synapse/replication/slave/storage/push_rule.py | 8 -- synapse/storage/data_stores/main/__init__.py | 17 ---- synapse/storage/data_stores/main/cache.py | 102 +++++++++++++++++++++- synapse/storage/data_stores/main/events_worker.py | 35 ++++++++ synapse/storage/data_stores/main/push_rule.py | 14 +++ synapse/storage/util/id_generators.py | 11 +++ 8 files changed, 161 insertions(+), 116 deletions(-) create mode 100644 changelog.d/7491.misc (limited to 'synapse/storage') diff --git a/changelog.d/7491.misc b/changelog.d/7491.misc new file mode 100644 index 0000000000..50eb226db7 --- /dev/null +++ b/changelog.d/7491.misc @@ -0,0 +1 @@ +Move event stream handling out of slave store. diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b313720a4b..1a1a50a24f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,11 +15,6 @@ # limitations under the License. import logging -from synapse.api.constants import EventTypes -from synapse.replication.tcp.streams.events import ( - EventsStreamCurrentStateRow, - EventsStreamEventRow, -) from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore from synapse.storage.data_stores.main.event_push_actions import ( EventPushActionsWorkerStore, @@ -35,7 +30,6 @@ from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -62,11 +56,6 @@ class SlavedEventStore( BaseSlavedStore, ): def __init__(self, database: Database, db_conn, hs): - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) - super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() @@ -92,81 +81,3 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": - self._stream_id_gen.advance(token) - for row in rows: - self._process_event_stream_row(token, row) - elif stream_name == "backfill": - self._backfill_id_gen.advance(-token) - for row in rows: - self.invalidate_caches_for_event( - -token, - row.event_id, - row.room_id, - row.type, - row.state_key, - row.redacts, - row.relates_to, - backfilled=True, - ) - return super().process_replication_rows(stream_name, instance_name, token, rows) - - def _process_event_stream_row(self, token, row): - data = row.data - - if row.type == EventsStreamEventRow.TypeId: - self.invalidate_caches_for_event( - token, - data.event_id, - data.room_id, - data.type, - data.state_key, - data.redacts, - data.relates_to, - backfilled=False, - ) - elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - - if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key,) - ) - else: - raise Exception("Unknown events stream row type %s" % (row.type,)) - - def invalidate_caches_for_event( - self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): - self._invalidate_get_event_cache(event_id) - - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) - - if not backfilled: - self._events_stream_cache.entity_has_changed(room_id, stream_ordering) - - if redacts: - self._invalidate_get_event_cache(redacts) - - if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) - - if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 5d5816d7eb..6adb19463a 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,19 +15,11 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore -from synapse.storage.database import Database -from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, database: Database, db_conn, hs): - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) - super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) - def get_push_rules_stream_token(self): return ( self._push_rules_stream_id_gen.get_current_token(), diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 5df9dce79d..4b4763c701 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( - ChainedIdGenerator, IdGenerator, MultiWriterIdGenerator, StreamIdGenerator, @@ -125,19 +124,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -164,9 +150,6 @@ class DataStore( self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) self._pushers_id_gen = StreamIdGenerator( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 342a87a46b..eac5a4e55b 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,8 +16,13 @@ import itertools import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, Tuple +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "caches": + if stream_name == "events": + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == "backfill": + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == "caches": if self._cache_id_gen: self._cache_id_gen.advance(instance_name, token) @@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 970c31bd05..9130b74eb5 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter @@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) + if hs.config.worker_app is None: + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + self._get_event_cache = Cache( "*getEvent*", keylen=3, @@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_list = [] self._event_fetch_ongoing = 0 + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b3faafa0a4..ef8f40959f 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -16,19 +16,23 @@ import abc import logging +from typing import Union from canonicaljson import json from twisted.internet import defer from synapse.push.baserules import list_with_base_rules +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.storage.util.id_generators import ChainedIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -64,6 +68,7 @@ class PushRulesWorkerStore( ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, + EventsWorkerStore, SQLBaseStore, ): """This is an abstract base class where subclasses must implement @@ -77,6 +82,15 @@ class PushRulesWorkerStore( def __init__(self, database: Database, db_conn, hs): super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 86d04ea9ac..f89ce0bed2 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -166,6 +166,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator + self._table = table self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] @@ -204,6 +205,16 @@ class ChainedIdGenerator(object): return self._current_max, self.chained_generator.get_current_token() + def advance(self, token: int): + """Stub implementation for advancing the token when receiving updates + over replication; raises an exception as this instance should be the + only source of updates. + """ + + raise Exception( + "Attempted to advance token on source for table %r", self._table + ) + class MultiWriterIdGenerator: """An ID generator that tracks a stream that can have multiple writers. -- cgit 1.5.1 From 16090a077f9f387dcd42edabda58063e3df6b771 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 15 May 2020 17:17:42 +0100 Subject: Prevent 0-member/null room_version rooms from appearing in group room queries (#7465) --- changelog.d/7465.bugfix | 1 + synapse/storage/data_stores/main/group_server.py | 92 ++++++++++++++++++++---- 2 files changed, 79 insertions(+), 14 deletions(-) create mode 100644 changelog.d/7465.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7465.bugfix b/changelog.d/7465.bugfix new file mode 100644 index 0000000000..1cbe50caa5 --- /dev/null +++ b/changelog.d/7465.bugfix @@ -0,0 +1 @@ +Prevent rooms with 0 members or with invalid version strings from breaking group queries. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 0963e6c250..fb1361f1c1 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id, include_private=False): + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ # TODO: Pagination - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + args = [group_id, False] - return self.db.simple_select_list( - table="group_rooms", - keyvalues=keyvalues, - retcols=("room_id", "is_public"), - desc="get_rooms_in_group", - ) + if not include_private: + sql += " AND is_public = ?" + args += [True] + + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] - def get_rooms_for_summary_by_category(self, group_id, include_private=False): + return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): """Get the rooms and categories that should be included in a summary request - Returns ([rooms], [categories]) + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category """ def _get_rooms_for_summary_txn(txn): @@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore): SELECT room_id, is_public, category_id, room_order FROM group_summary_rooms WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) """ if not include_private: sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) + txn.execute(sql, (group_id, False, True)) else: - txn.execute(sql, (group_id,)) + txn.execute(sql, (group_id, False)) rooms = [ { -- cgit 1.5.1 From 03aff4c75ed3b0b106ed1395b3d03b1ab9b013a6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 May 2020 17:22:47 +0100 Subject: Add a worker store for search insertion. (#7516) This is required as both event persistence and the background update needs access to this function. It should be perfectly safe for two workers to write to that table at the same time. --- changelog.d/7516.misc | 1 + synapse/app/generic_worker.py | 2 + synapse/storage/data_stores/main/search.py | 96 +++++++++++++++--------------- 3 files changed, 52 insertions(+), 47 deletions(-) create mode 100644 changelog.d/7516.misc (limited to 'synapse/storage') diff --git a/changelog.d/7516.misc b/changelog.d/7516.misc new file mode 100644 index 0000000000..94b0fd49b2 --- /dev/null +++ b/changelog.d/7516.misc @@ -0,0 +1 @@ +Add a worker store for search insertion, required for moving event persistence off master. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2e3add7ac5..ab801108ca 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.search import SearchWorkerStore from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -451,6 +452,7 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + SearchWorkerStore, BaseSlavedStore, ): def __init__(self, database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ee75b92344..13f49d8060 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -37,7 +37,55 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(SQLBaseStore): +class SearchWorkerStore(SQLBaseStore): + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore): return num_rows - def store_search_entries_txn(self, txn, entries): - """Add entries to the search table - - Args: - txn (cursor): - entries (iterable[SearchEntry]): - entries to be added to the table - """ - if not self.hs.config.enable_search: - return - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - - args = ( - ( - entry.event_id, - entry.room_id, - entry.key, - entry.value, - entry.stream_ordering, - entry.origin_server_ts, - ) - for entry in entries - ) - - txn.executemany(sql, args) - - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) - for entry in entries - ) - - txn.executemany(sql, args) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): -- cgit 1.5.1 From 34a43f0084ae2e16eb6e8f6eb2ecf4bb37bd4cf0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 15 May 2020 18:53:31 +0100 Subject: Fix a couple of small typos --- synapse/appservice/__init__.py | 2 +- synapse/storage/data_stores/main/appservice.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index aea3985a5f..1b13e84425 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -270,7 +270,7 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exlusive_user_regexes(self): + def get_exclusive_user_regexes(self): """Get the list of regexes used to determine if a user is exclusively registered by the AS """ diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index efbc06c796..7a1fe8cdd2 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -30,12 +30,12 @@ logger = logging.getLogger(__name__) def _make_exclusive_regex(services_cache): - # We precompie a regex constructed from all the regexes that the AS's + # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ regex.pattern for service in services_cache - for regex in service.get_exlusive_user_regexes() + for regex in service.get_exclusive_user_regexes() ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) -- cgit 1.5.1 From 6c1f7c722f0baade9aecf41f600fcced670c4fcb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 May 2020 19:03:25 +0100 Subject: Fix limit logic for AccountDataStream (#7384) Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream. --- changelog.d/7384.bugfix | 1 + synapse/replication/tcp/streams/_base.py | 68 +++++++++--- synapse/storage/data_stores/main/account_data.py | 62 +++++++---- tests/replication/tcp/streams/test_account_data.py | 117 +++++++++++++++++++++ 4 files changed, 217 insertions(+), 31 deletions(-) create mode 100644 changelog.d/7384.bugfix create mode 100644 tests/replication/tcp/streams/test_account_data.py (limited to 'synapse/storage') diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7384.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b48a6a3e91..d42aaff055 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,14 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import heapq import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) # the number of rows to request from an update_function. @@ -37,7 +50,7 @@ Token = int # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's # just a row from a database query, though this is dependent on the stream in question. # -StreamRow = Tuple +StreamRow = TypeVar("StreamRow", bound=Tuple) # The type returned by the update_function of a stream, as well as get_updates(), # get_updates_since, etc. @@ -533,32 +546,63 @@ class AccountDataStream(Stream): """ AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str ) NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), - db_query_to_update_function(self._update_function), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit ) - async def _update_function(self, from_token, to_token, limit): - global_results, room_results = await self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True + + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type) + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token + ) + + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(room_rows, global_rows)) + return updates, to_token, limited class GroupServerStream(Stream): diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 46b494b334..f9eef1b78e 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, and type string. + A list of tuples of stream_id int, user_id string, + and type string. """ - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) + if last_id == current_id: + return [] - def get_updated_account_data_txn(txn): + def get_updated_global_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return await self.db.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() - return self.db.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn + return await self.db.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py new file mode 100644 index 0000000000..6a5116dd2a --- /dev/null +++ b/tests/replication/tcp/streams/test_account_data.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + AccountDataStream, +) + +from tests.replication._base import BaseStreamTestCase + + +class AccountDataStreamTestCase(BaseStreamTestCase): + def test_update_function_room_account_data_limit(self): + """Test replication with many room account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success( + store.add_account_data_to_room("test_user", "test_room", update, {}) + ) + updates.append(update) + + # also one global update + self.get_success(store.add_account_data_for_user("test_user", "m.global", {})) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertEqual(row.room_id, "test_room") + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.global") + self.assertIsNone(row.room_id) + + self.assertEqual([], received_rows) + + def test_update_function_global_account_data_limit(self): + """Test replication with many global account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success(store.add_account_data_for_user("test_user", update, {})) + updates.append(update) + + # also one per-room update + self.get_success( + store.add_account_data_to_room("test_user", "test_room", "m.per_room", {}) + ) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertIsNone(row.room_id) + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.per_room") + self.assertEqual(row.room_id, "test_room") + + self.assertEqual([], received_rows) -- cgit 1.5.1 From 08fa96f03037178620f5f0dd609fac52fbf7f2d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:07:24 +0100 Subject: Remove `exception_to_unicode` this is a no-op on python 3. --- synapse/storage/database.py | 15 +++------------ synapse/util/stringutils.py | 36 ------------------------------------ 2 files changed, 3 insertions(+), 48 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c3d0863429..9947dbce77 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor from synapse.types import Collection -from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -424,20 +423,14 @@ class Database(object): # This can happen if the database disappears mid # transaction. logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, + "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) + logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: @@ -449,9 +442,7 @@ class Database(object): conn.rollback() except self.engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, e1, ) continue raise diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 6899bcb788..2cfa5cf721 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -85,42 +85,6 @@ def to_ascii(s): return s -def exception_to_unicode(e): - """Helper function to extract the text of an exception as a unicode string - - Args: - e (Exception): exception to be stringified - - Returns: - unicode - """ - # urgh, this is a mess. The basic problem here is that psycopg2 constructs its - # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will - # then produce the raw byte sequence. Under Python 2, this will then cause another - # error if it gets mixed with a `unicode` object, as per - # https://github.com/matrix-org/synapse/issues/4252 - - # First of all, if we're under python3, everything is fine because it will sort this - # nonsense out for us. - if not PY2: - return str(e) - - # otherwise let's have a stab at decoding the exception message. We'll circumvent - # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii') - # and instead look at what is in the args member. - - if len(e.args) == 0: - return "" - elif len(e.args) > 1: - return six.text_type(repr(e.args)) - - msg = e.args[0] - if isinstance(msg, bytes): - return msg.decode("utf-8", errors="replace") - else: - return msg - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.5.1 From 65902e08c3f4449de9baa4e6466f126585f688b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:12:03 +0100 Subject: remove to_ascii this is a no-op on python 3. --- synapse/storage/data_stores/main/roommember.py | 25 ++++++++++--------------- synapse/storage/data_stores/main/state.py | 5 +---- synapse/util/stringutils.py | 20 +------------------- 3 files changed, 12 insertions(+), 38 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 48810a3e91..1e9c850152 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.metrics import Measure -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] + return [r[0] for r in txn] @cached(max_entries=100000) def get_room_summary(self, room_id): @@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + summary = res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] + summary = res[membership] # we will always have a summary for this membership type at this # point given the summary currently contains the counts. members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) + members.append((user_id, event_id)) return res @@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry: if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), + users_in_room[ev_entry.event.state_key] = ProfileInfo( + display_name=ev_entry.event.content.get("displayname", None), + avatar_url=ev_entry.event.content.get("avatar_url", None), ) else: missing_member_event_ids.append(event_id) @@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), + users_in_room[event.state_key] = ProfileInfo( + display_name=event.content.get("displayname", None), + avatar_url=event.content.get("avatar_url", None), ) return users_in_room diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 21052fcc7a..347cc50778 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -29,7 +29,6 @@ from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): (room_id,), ) - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return self.db.runInteraction( "get_current_state_ids", _get_current_state_ids_txn diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2cfa5cf721..81a44184ca 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,8 +19,7 @@ import re import string from collections import Iterable -import six -from six import PY2, PY3 +from six import PY3 from six.moves import range from synapse.api.errors import Codes, SynapseError @@ -68,23 +67,6 @@ def is_ascii(s): return True -def to_ascii(s): - """Converts a string to ascii if it is ascii, otherwise leave it alone. - - If given None then will return None. - """ - if PY3: - return s - - if s is None: - return None - - try: - return s.encode("ascii") - except UnicodeEncodeError: - return s - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.5.1 From e6027562e2a1964bcaa0163f1615ab72bfc6630b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:26:54 +0100 Subject: remove `builtins.buffer` code from storage code this is no longer needed on python 3 --- scripts-dev/convert_server_keys.py | 9 ++------- synapse/storage/_base.py | 8 -------- synapse/storage/data_stores/main/keys.py | 10 ++-------- synapse/storage/data_stores/main/transactions.py | 9 +-------- 4 files changed, 5 insertions(+), 31 deletions(-) (limited to 'synapse/storage') diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 06b4c1e2ff..961dc59f11 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -3,8 +3,6 @@ import json import sys import time -import six - import psycopg2 import yaml from canonicaljson import encode_canonical_json @@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -if six.PY2: - db_type = six.moves.builtins.buffer -else: - db_type = memoryview +db_binary_type = memoryview def select_v1_keys(connection): @@ -72,7 +67,7 @@ def rows_v2(server, json): valid_until = json["valid_until_ts"] key_json = encode_canonical_json(json) for key_id in json["verify_keys"]: - yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) + yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json)) def main(): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 59073c0a42..bfce541ca7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,9 +19,6 @@ import random from abc import ABCMeta from typing import Any, Optional -from six import PY2 -from six.moves import builtins - from canonicaljson import json from synapse.storage.database import LoggingTransaction # noqa: F401 @@ -103,11 +100,6 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # psycopg2 on Python 2 returns buffer objects, which we need to cast to - # bytes to decode - if PY2 and isinstance(db_content, builtins.buffer): - db_content = bytes(db_content) - # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. if isinstance(db_content, (bytes, bytearray)): diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ba89c68c9f..4e1642a27a 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -17,8 +17,6 @@ import itertools import logging -import six - from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore @@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview + +db_binary_type = memoryview class KeyStore(SQLBaseStore): diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 5b07c2fbc0..a9bf457939 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -16,8 +16,6 @@ import logging from collections import namedtuple -import six - from canonicaljson import encode_canonical_json from twisted.internet import defer @@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview +db_binary_type = memoryview logger = logging.getLogger(__name__) -- cgit 1.5.1 From f6f92845f84301619f4c48eed9364b25db0b1445 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 May 2020 13:20:10 +0100 Subject: Fix bug in persist events when dealing with non member types. (#7548) `_is_server_still_joined` will throw if it is given state updates with non-user ID state keys with local user leaves. This is actually rarely a problem since local leaves almost always get persisted by themselves. (I discovered this on a branch that was otherwise broken, so I haven't seen this in the wild) --- changelog.d/7548.bugfix | 1 + synapse/storage/persist_events.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7548.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7548.bugfix b/changelog.d/7548.bugfix new file mode 100644 index 0000000000..1233b3b31a --- /dev/null +++ b/changelog.d/7548.bugfix @@ -0,0 +1 @@ +Fix bug where a local user leaving a room could fail under rare circumstances. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 41881ea20b..12e1ffb9a2 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -740,8 +740,8 @@ class EventsPersistenceStorage(object): # whose state has changed as we've already their new state above. users_to_ignore = [ state_key - for _, state_key in itertools.chain(delta.to_insert, delta.to_delete) - if self.is_mine_id(state_key) + for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete) + if typ == EventTypes.Member and self.is_mine_id(state_key) ] if await self.main_store.is_local_host_in_room_ignoring_users( -- cgit 1.5.1 From d1ae1015ecb7dcb36d7cb39d5d41733e1ced2a52 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 21 May 2020 17:41:12 +0200 Subject: Retry to sync out of sync device lists (#7453) When a call to `user_device_resync` fails, we don't currently mark the remote user's device list as out of sync, nor do we retry to sync it. https://github.com/matrix-org/synapse/pull/6776 introduced some code infrastructure to mark device lists as stale/out of sync. This commit uses that code infrastructure to mark device lists as out of sync if processing an incoming device list update makes the device handler realise that the device list is out of sync, but we can't resync right now. It also adds a looping call to retry all failed resync every 30s. This shouldn't cause too much spam in the logs as this commit also removes the "Failed to handle device list update for..." warning logs when catching `NotRetryingDestination`. Fixes #7418 --- changelog.d/7453.bugfix | 1 + synapse/handlers/device.py | 80 ++++++++++++++++++++++++++--- synapse/storage/data_stores/main/devices.py | 34 +++++++----- tests/test_federation.py | 63 ++++++++++++++++++++++- 4 files changed, 158 insertions(+), 20 deletions(-) create mode 100644 changelog.d/7453.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7453.bugfix b/changelog.d/7453.bugfix new file mode 100644 index 0000000000..629a3d61e2 --- /dev/null +++ b/changelog.d/7453.bugfix @@ -0,0 +1 @@ +Fix a bug that would cause Synapse not to resync out-of-sync device lists. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9bd941b5a0..29a19b4572 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,6 +29,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import stringutils from synapse.util.async_helpers import Linearizer @@ -535,6 +536,15 @@ class DeviceListUpdater(object): iterable=True, ) + # Attempt to resync out of sync device lists every 30s. + self._resync_retry_in_progress = False + self.clock.looping_call( + run_as_background_process, + 30 * 1000, + func=self._maybe_retry_device_resync, + desc="_maybe_retry_device_resync", + ) + @trace @defer.inlineCallbacks def incoming_device_list_update(self, origin, edu_content): @@ -679,11 +689,50 @@ class DeviceListUpdater(object): return False @defer.inlineCallbacks - def user_device_resync(self, user_id): + def _maybe_retry_device_resync(self): + """Retry to resync device lists that are out of sync, except if another retry is + in progress. + """ + if self._resync_retry_in_progress: + return + + try: + # Prevent another call of this function to retry resyncing device lists so + # we don't send too many requests. + self._resync_retry_in_progress = True + # Get all of the users that need resyncing. + need_resync = yield self.store.get_user_ids_requiring_device_list_resync() + # Iterate over the set of user IDs. + for user_id in need_resync: + # Try to resync the current user's devices list. Exception handling + # isn't necessary here, since user_device_resync catches all instances + # of "Exception" that might be raised from the federation request. This + # means that if an exception is raised by this function, it must be + # because of a database issue, which means _maybe_retry_device_resync + # probably won't be able to go much further anyway. + result = yield self.user_device_resync( + user_id=user_id, mark_failed_as_stale=False, + ) + # user_device_resync only returns a result if it managed to successfully + # resync and update the database. Updating the table of users requiring + # resync isn't necessary here as user_device_resync already does it + # (through self.store.update_remote_device_list_cache). + if result: + logger.debug( + "Successfully resynced the device list for %s" % user_id, + ) + finally: + # Allow future calls to retry resyncinc out of sync device lists. + self._resync_retry_in_progress = False + + @defer.inlineCallbacks + def user_device_resync(self, user_id, mark_failed_as_stale=True): """Fetches all devices for a user and updates the device cache with them. Args: user_id (str): The user's id whose device_list will be updated. + mark_failed_as_stale (bool): Whether to mark the user's device list as stale + if the attempt to resync failed. Returns: Deferred[dict]: a dict with device info as under the "devices" in the result of this request: @@ -694,10 +743,23 @@ class DeviceListUpdater(object): origin = get_domain_from_id(user_id) try: result = yield self.federation.query_user_devices(origin, user_id) - except (NotRetryingDestination, RequestSendFailed, HttpResponseException): - # TODO: Remember that we are now out of sync and try again - # later - logger.warning("Failed to handle device list update for %s", user_id) + except NotRetryingDestination: + if mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + yield self.store.mark_remote_user_device_cache_as_stale(user_id) + + return + except (RequestSendFailed, HttpResponseException) as e: + logger.warning( + "Failed to handle device list update for %s: %s", user_id, e, + ) + + if mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + yield self.store.mark_remote_user_device_cache_as_stale(user_id) + # We abort on exceptions rather than accepting the update # as otherwise synapse will 'forget' that its device list # is out of date. If we bail then we will retry the resync @@ -711,13 +773,17 @@ class DeviceListUpdater(object): logger.info(e) return except Exception as e: - # TODO: Remember that we are now out of sync and try again - # later set_tag("error", True) log_kv( {"message": "Exception raised by federation request", "exception": e} ) logger.exception("Failed to handle device list update for %s", user_id) + + if mark_failed_as_stale: + # Mark the remote user's device list as stale so we know we need to retry + # it later. + yield self.store.mark_remote_user_device_cache_as_stale(user_id) + return log_kv({"result": result}) stream_id = result["stream_id"] diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index fe6d6ecfe0..0e8378714a 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import List, Tuple +from typing import List, Optional, Set, Tuple from six import iteritems @@ -649,21 +649,31 @@ class DeviceWorkerStore(SQLBaseStore): return results @defer.inlineCallbacks - def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]): + def get_user_ids_requiring_device_list_resync( + self, user_ids: Optional[Collection[str]] = None, + ) -> Set[str]: """Given a list of remote users return the list of users that we - should resync the device lists for. + should resync the device lists for. If None is given instead of a list, + return every user that we should resync the device lists for. Returns: - Deferred[Set[str]] + The IDs of users whose device lists need resync. """ - - rows = yield self.db.simple_select_many_batch( - table="device_lists_remote_resync", - column="user_id", - iterable=user_ids, - retcols=("user_id",), - desc="get_user_ids_requiring_device_list_resync", - ) + if user_ids: + rows = yield self.db.simple_select_many_batch( + table="device_lists_remote_resync", + column="user_id", + iterable=user_ids, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync_with_iterable", + ) + else: + rows = yield self.db.simple_select_list( + table="device_lists_remote_resync", + keyvalues=None, + retcols=("user_id",), + desc="get_user_ids_requiring_device_list_resync", + ) return {row["user_id"] for row in rows} diff --git a/tests/test_federation.py b/tests/test_federation.py index f297de95f1..13ff14863e 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -6,12 +6,13 @@ from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.types import Requester, UserID from synapse.util import Clock +from synapse.util.retryutils import NotRetryingDestination from tests import unittest from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver -class MessageAcceptTests(unittest.TestCase): +class MessageAcceptTests(unittest.HomeserverTestCase): def setUp(self): self.http_client = Mock() @@ -145,3 +146,63 @@ class MessageAcceptTests(unittest.TestCase): # Make sure the invalid event isn't there extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + + def test_retry_device_list_resync(self): + """Tests that device lists are marked as stale if they couldn't be synced, and + that stale device lists are retried periodically. + """ + remote_user_id = "@john:test_remote" + remote_origin = "test_remote" + + # Track the number of attempts to resync the user's device list. + self.resync_attempts = 0 + + # When this function is called, increment the number of resync attempts (only if + # we're querying devices for the right user ID), then raise a + # NotRetryingDestination error to fail the resync gracefully. + def query_user_devices(destination, user_id): + if user_id == remote_user_id: + self.resync_attempts += 1 + + raise NotRetryingDestination(0, 0, destination) + + # Register the mock on the federation client. + federation_client = self.homeserver.get_federation_client() + federation_client.query_user_devices = Mock(side_effect=query_user_devices) + + # Register a mock on the store so that the incoming update doesn't fail because + # we don't share a room with the user. + store = self.homeserver.get_datastore() + store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) + + # Manually inject a fake device list update. We need this update to include at + # least one prev_id so that the user's device list will need to be retried. + device_list_updater = self.homeserver.get_device_handler().device_list_updater + self.get_success( + device_list_updater.incoming_device_list_update( + origin=remote_origin, + edu_content={ + "deleted": False, + "device_display_name": "Mobile", + "device_id": "QBUAZIFURK", + "prev_id": [5], + "stream_id": 6, + "user_id": remote_user_id, + }, + ) + ) + + # Check that there was one resync attempt. + self.assertEqual(self.resync_attempts, 1) + + # Check that the resync attempt failed and caused the user's device list to be + # marked as stale. + need_resync = self.get_success( + store.get_user_ids_requiring_device_list_resync() + ) + self.assertIn(remote_user_id, need_resync) + + # Check that waiting for 30 seconds caused Synapse to retry resyncing the device + # list. + self.reactor.advance(30) + self.assertEqual(self.resync_attempts, 2) -- cgit 1.5.1 From 1531b214fc57714c14046a8f66c7b5fe5ec5dcdd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 14:21:54 +0100 Subject: Add ability to wait for replication streams (#7542) The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room). Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on. People probably want to look at this commit by commit. --- changelog.d/7542.misc | 1 + synapse/handlers/federation.py | 33 ++++++--- synapse/handlers/message.py | 36 ++++++--- synapse/handlers/room.py | 65 +++++++++++----- synapse/handlers/room_member.py | 65 ++++++++++------ synapse/handlers/room_member_worker.py | 11 +-- synapse/replication/http/federation.py | 13 +++- synapse/replication/http/membership.py | 14 ++-- synapse/replication/http/send_event.py | 4 +- synapse/replication/http/streams.py | 5 +- synapse/replication/tcp/client.py | 90 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 10 ++- synapse/rest/client/v1/room.py | 20 +++-- synapse/rest/client/v2_alpha/relations.py | 2 +- synapse/server.py | 5 ++ synapse/server.pyi | 5 ++ synapse/server_notices/server_notices_manager.py | 6 +- synapse/storage/data_stores/main/events_worker.py | 6 +- synapse/storage/data_stores/main/roommember.py | 2 + tests/federation/test_complexity.py | 8 +- tests/handlers/test_typing.py | 5 +- tests/storage/test_cleanup_extrems.py | 4 +- tests/storage/test_event_metrics.py | 2 +- tests/test_federation.py | 4 +- 24 files changed, 304 insertions(+), 112 deletions(-) create mode 100644 changelog.d/7542.misc (limited to 'synapse/storage') diff --git a/changelog.d/7542.misc b/changelog.d/7542.misc new file mode 100644 index 0000000000..7dd9b4823b --- /dev/null +++ b/changelog.d/7542.misc @@ -0,0 +1 @@ +Add ability to wait for replication streams. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bb03cc9add..e354c803db 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,6 +126,7 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._replication = hs.get_replication_data_handler() self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( hs @@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict - ) -> None: + ) -> Tuple[str, int]: """ Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. @@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler): room_id=room_id, room_version=room_version_obj, ) - await self._persist_auth_tree( + max_stream_id = await self._persist_auth_tree( origin, auth_chain, state, event, room_version_obj ) + # We wait here until this instance has seen the events come down + # replication (if we're using replication) as the below uses caches. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position( + "master", "events", max_stream_id + ) + # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = await self.store.get_room_predecessor(room_id) if not predecessor or not isinstance(predecessor.get("room_id"), str): - return + return event.event_id, max_stream_id old_room_id = predecessor["room_id"] logger.debug( "Found predecessor for %s during remote join: %s", room_id, old_room_id @@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler): ) logger.debug("Finished joining %s to %s", joinee, room_id) + return event.event_id, max_stream_id finally: room_queue = self.room_queues[room_id] del self.room_queues[room_id] @@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler): async def do_remotely_reject_invite( self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict - ) -> EventBase: + ) -> Tuple[EventBase, int]: origin, event, room_version = await self._make_and_verify_event( target_hosts, room_id, user_id, "leave", content=content ) @@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(target_hosts, event) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify([(event, context)]) - return event + return event, stream_id async def _make_and_verify_event( self, @@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler): state: List[EventBase], event: EventBase, room_version: RoomVersion, - ) -> None: + ) -> int: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler): event, old_state=state ) - await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify([(event, new_event_context)]) async def _prep_event( self, @@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler): self, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> None: + ) -> int: """Persists events and tells the notifier/pushers about them, if necessary. @@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler): backfilling or not """ if self.config.worker_app: - await self._send_events_to_master( + result = await self._send_events_to_master( store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, ) + return result["max_stream_id"] else: max_stream_id = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled @@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler): for event, _ in event_and_contexts: await self._notify_persisted_event(event, max_stream_id) + return max_stream_id + async def _notify_persisted_event( self, event: EventBase, max_stream_id: int ) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8f362896a2..f445e2aa2a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple from six import iteritems, itervalues, string_types @@ -42,6 +42,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.events import EventBase from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -630,7 +631,9 @@ class EventCreationHandler(object): msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) - async def send_nonmember_event(self, requester, event, context, ratelimit=True): + async def send_nonmember_event( + self, requester, event, context, ratelimit=True + ) -> int: """ Persists and notifies local clients and federation of an event. @@ -639,6 +642,9 @@ class EventCreationHandler(object): context (Context) the context of the event. ratelimit (bool): Whether to rate limit this send. is_guest (bool): Whether the sender is a guest. + + Return: + The stream_id of the persisted event. """ if event.type == EventTypes.Member: raise SynapseError( @@ -659,7 +665,7 @@ class EventCreationHandler(object): ) return prev_state - await self.handle_new_client_event( + return await self.handle_new_client_event( requester=requester, event=event, context=context, ratelimit=ratelimit ) @@ -688,7 +694,7 @@ class EventCreationHandler(object): async def create_and_send_nonmember_event( self, requester, event_dict, ratelimit=True, txn_id=None - ): + ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. @@ -711,10 +717,10 @@ class EventCreationHandler(object): spam_error = "Spam is not permitted here" raise SynapseError(403, spam_error, Codes.FORBIDDEN) - await self.send_nonmember_event( + stream_id = await self.send_nonmember_event( requester, event, context, ratelimit=ratelimit ) - return event + return event, stream_id @measure_func("create_new_client_event") @defer.inlineCallbacks @@ -774,7 +780,7 @@ class EventCreationHandler(object): @measure_func("handle_new_client_event") async def handle_new_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Processes a new event. This includes checking auth, persisting it, notifying users, sending to remote servers, etc. @@ -787,6 +793,9 @@ class EventCreationHandler(object): context (EventContext) ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event + + Return: + The stream_id of the persisted event. """ if event.is_state() and (event.type, event.state_key) == ( @@ -827,7 +836,7 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. if self.config.worker_app: - await self.send_event_to_master( + result = await self.send_event_to_master( event_id=event.event_id, store=self.store, requester=requester, @@ -836,14 +845,17 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) + stream_id = result["stream_id"] + event.internal_metadata.stream_ordering = stream_id success = True - return + return stream_id - await self.persist_and_notify_client_event( + stream_id = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) success = True + return stream_id finally: if not success: # Ensure that we actually remove the entries in the push actions @@ -886,7 +898,7 @@ class EventCreationHandler(object): async def persist_and_notify_client_event( self, requester, event, context, ratelimit=True, extra_users=[] - ): + ) -> int: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1076,6 +1088,8 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) + return event_stream_id + async def _bump_active_time(self, user): try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 13850ba672..2698a129ca 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -22,6 +22,7 @@ import logging import math import string from collections import OrderedDict +from typing import Tuple from six import iteritems, string_types @@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler): async def create_room( self, requester, config, ratelimit=True, creator_join_profile=None - ): + ) -> Tuple[dict, int]: """ Creates a new room. Args: @@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler): `avatar_url` and/or `displayname`. Returns: - Deferred[dict]: - a dict containing the keys `room_id` and, if an alias was - requested, `room_alias`. + First, a dict containing the keys `room_id` and, if an alias + was, requested, `room_alias`. Secondly, the stream_id of the + last persisted event. Raises: SynapseError if the room ID couldn't be stored, or something went horribly wrong. @@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler): # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier - await self._send_events_for_new_room( + last_stream_id = await self._send_events_for_new_room( requester, room_id, preset_config=preset_config, @@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler): if "name" in config: name = config["name"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Name, @@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler): if "topic" in config: topic = config["topic"] - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Topic, @@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct - await self.room_member_handler.update_membership( + _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), room_id, @@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] - await self.hs.get_room_member_handler().do_3pid_invite( + last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, medium, @@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - return result + return result, last_stream_id async def _send_events_for_new_room( self, @@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler): room_alias=None, power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, - ): + ) -> int: + """Sends the initial events into a new room. + + Returns: + The stream_id of the last event persisted. + """ + def create(etype, content, **kwargs): e = {"type": etype, "content": content} @@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype, content, **kwargs): + async def send(etype, content, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + _, + last_stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( creator, event, ratelimit=False ) + return last_stream_id config = RoomCreationHandler.PRESETS_DICT[preset_config] @@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler): # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) if pl_content is not None: - await send(etype=EventTypes.PowerLevels, content=pl_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=pl_content + ) else: power_level_content = { "users": {creator_id: 100}, @@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler): if power_level_content_override: power_level_content.update(power_level_content_override) - await send(etype=EventTypes.PowerLevels, content=power_level_content) + last_sent_stream_id = await send( + etype=EventTypes.PowerLevels, content=power_level_content + ) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.CanonicalAlias, content={"alias": room_alias.to_string()}, ) if (EventTypes.JoinRules, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} ) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.RoomHistoryVisibility, content={"history_visibility": config["history_visibility"]}, ) if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: - await send( + last_sent_stream_id = await send( etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} ) for (etype, state_key), content in initial_state.items(): - await send(etype=etype, state_key=state_key, content=content) + last_sent_stream_id = await send( + etype=etype, state_key=state_key, content=content + ) + + return last_sent_stream_id async def _generate_room_id( self, creator_id: str, is_public: str, room_version: RoomVersion, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e51e1c32fe..691b6705b2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple from six.moves import http_client @@ -84,7 +84,7 @@ class RoomMemberHandler(object): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Try and join a room that this server is not in Args: @@ -104,7 +104,7 @@ class RoomMemberHandler(object): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -154,7 +154,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> EventBase: + ) -> Tuple[str, int]: user_id = target.to_string() if content is None: @@ -187,9 +187,10 @@ class RoomMemberHandler(object): ) if duplicate is not None: # Discard the new event since this membership change is a no-op. - return duplicate + _, stream_id = await self.store.get_event_ordering(duplicate.event_id) + return duplicate.event_id, stream_id - await self.event_creation_handler.handle_new_client_event( + stream_id = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target], ratelimit=ratelimit ) @@ -213,7 +214,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target, room_id) - return event + return event.event_id, stream_id async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id @@ -263,7 +264,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -294,7 +295,7 @@ class RoomMemberHandler(object): ratelimit: bool = True, content: Optional[dict] = None, require_consent: bool = True, - ) -> Union[EventBase, Optional[dict]]: + ) -> Tuple[Optional[str], int]: content_specified = bool(content) if content is None: content = {} @@ -398,7 +399,13 @@ class RoomMemberHandler(object): same_membership = old_membership == effective_membership_state same_sender = requester.user.to_string() == old_state.sender if same_sender and same_membership and same_content: - return old_state + _, stream_id = await self.store.get_event_ordering( + old_state.event_id + ) + return ( + old_state.event_id, + stream_id, + ) if old_membership in ["ban", "leave"] and action == "kick": raise AuthError(403, "The target user is not in the room") @@ -705,7 +712,7 @@ class RoomMemberHandler(object): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -737,11 +744,11 @@ class RoomMemberHandler(object): ) if invitee: - await self.update_membership( + _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - await self._make_and_store_3pid_invite( + stream_id = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -752,6 +759,8 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) + return stream_id + async def _make_and_store_3pid_invite( self, requester: Requester, @@ -762,7 +771,7 @@ class RoomMemberHandler(object): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> None: + ) -> int: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -817,7 +826,10 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + stream_id, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.ThirdPartyInvite, @@ -835,6 +847,7 @@ class RoomMemberHandler(object): ratelimit=False, txn_id=txn_id, ) + return stream_id async def _is_host_in_room( self, current_state_ids: Dict[Tuple[str, str], str] @@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> None: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # join dance for now, since we're kinda implicitly checking # that we are allowed to join when we decide whether or not we # need to do the invite/join dance. - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content ) await self._user_joined_room(user, room_id) @@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler): if self.hs.config.limit_remote_rooms.enabled: if too_complex is False: # We checked, and we're under the limit. - return + return event_id, stream_id # Check again, but with the local state events too_complex = await self._is_local_room_too_complex(room_id) if too_complex is False: # We're under the limit. - return + return event_id, stream_id # The room is too large. Leave. requester = types.create_requester(user, None, False, None) @@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler): errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) + return event_id, stream_id + async def _remote_reject_invite( self, requester: Requester, @@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: - ret = await fed_handler.do_remotely_reject_invite( + event, stream_id = await fed_handler.do_remotely_reject_invite( remote_room_hosts, room_id, target.to_string(), content=content, ) - return ret + return event.event_id, stream_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(target.to_string(), room_id) - return {} + stream_id = await self.store.locally_reject_invite( + target.to_string(), room_id + ) + return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 5c776cc0be..02e0c4103d 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Optional[dict]: + ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._user_joined_room(user, room_id) - return ret + return ret["event_id"], ret["stream_id"] async def _remote_reject_invite( self, @@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, target: UserID, content: dict, - ) -> dict: + ) -> Tuple[Optional[str], int]: """Implements RoomMemberHandler._remote_reject_invite """ - return await self._remote_reject_client( + ret = await self._remote_reject_client( requester=requester, remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), content=content, ) + return ret["event_id"], ret["stream_id"] async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7e23b565b9..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 3577611fd7..050fd34562 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = await self.federation_handler.do_remotely_reject_invite( + event, stream_id = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.store.locally_reject_invite(user_id, room_id) + event_id = None - return 200, ret + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index b74b088ff4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - await self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index b705a8e16c..bde97eef32 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): super().__init__(hs) self._instance_name = hs.get_instance_name() - - # We pull the streams from the replication handler (if we try and make - # them ourselves we end up in an import loop). - self.streams = hs.get_tcp_replication().get_streams() + self.streams = hs.get_replication_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 28826302f5..508ad1b720 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,19 +14,23 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,6 +39,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. @@ -92,6 +100,16 @@ class ReplicationDataHandler: self.store = hs.get_datastore() self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -131,8 +149,76 @@ class ReplicationDataHandler: await self.pusher_pool.on_new_notifications(token, token) + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any futher entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. + """ + + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) + + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. + heapq.heappush(waiting_list, (position, deferred)) + + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7d40001988..0a13e1ed34 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet): self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self._replication = hs.get_replication_data_handler() async def on_POST(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet): message = content.get("message", self.DEFAULT_MESSAGE) room_name = content.get("room_name", "Content Violation Notification") - info = await self._room_creation_handler.create_room( + info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": "public_chat", @@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet): # desirable in case the first attempt at blocking the room failed below. await self.store.block_room(room_id, requester_user_id) + # We now wait for the create room to come back in via replication so + # that we can assume that all the joins/invites have propogated before + # we try and auto join below. + # + # TODO: Currently the events stream is written to from master + await self._replication.wait_for_stream_position("master", "events", stream_id) + users = await self.state.get_current_users_in_room(room_id) kicked_users = [] failed_to_kick_users = [] diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6b5830cc3f..105e0cf4d2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet): async def on_POST(self, request): requester = await self.auth.get_user_by_req(request) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = await self.room_member_handler.update_membership( + event_id, _ = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = await self.event_creation_handler.create_and_send_nonmember_event( + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) + event_id = event.event_id ret = {} # type: dict - if event: - set_tag("event_id", event.event_id) - ret = {"event_id": event.event_id} + if event_id: + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 63f07b63da..89002ffbff 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = await self.event_creation_handler.create_and_send_nonmember_event( + event, _ = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) diff --git a/synapse/server.py b/synapse/server.py index c530f1aa1a..ca2deb49bb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer +from synapse.replication.tcp.streams import STREAMS_MAP from synapse.rest.media.v1.media_repository import ( MediaRepository, MediaRepositoryResource, @@ -210,6 +211,7 @@ class HomeServer(object): "storage", "replication_streamer", "replication_data_handler", + "replication_streams", ] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] @@ -583,6 +585,9 @@ class HomeServer(object): def build_replication_data_handler(self): return ReplicationDataHandler(self) + def build_replication_streams(self): + return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()} + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9e7fad7e6e..fe8024d2d4 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,3 +1,5 @@ +from typing import Dict + import twisted.internet import synapse.api.auth @@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.replication.tcp.streams import Stream class HomeServer(object): @property @@ -136,3 +139,5 @@ class HomeServer(object): pass def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: pass + def get_replication_streams(self) -> Dict[str, Stream]: + pass diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 999c621b92..bf2454c01c 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -83,10 +83,10 @@ class ServerNoticesManager(object): if state_key is not None: event_dict["state_key"] = state_key - res = await self._event_creation_handler.create_and_send_nonmember_event( + event, _ = await self._event_creation_handler.create_and_send_nonmember_event( requester, event_dict, ratelimit=False ) - return res + return event @cached() async def get_or_create_notice_room_for_user(self, user_id): @@ -143,7 +143,7 @@ class ServerNoticesManager(object): } requester = create_requester(self.server_notices_mxid) - info = await self._room_creation_handler.create_room( + info, _ = await self._room_creation_handler.create_room( requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9130b74eb5..b880a71782 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore): async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) + to_1, so_1 = await self.get_event_ordering(event_id1) + to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): + def get_event_ordering(self, event_id): res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 1e9c850152..7c5ca81ae0 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) + return stream_ordering + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 94980733c4..0c9987be54 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) d = handler._remote_join( None, @@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): # Mock out some things, because we don't want to test the whole join fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) - handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) # Artificially raise the complexity self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 51e2b37218..2fa8d4739b 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): reactor.pump((1000,)) hs = self.setup_test_homeserver( - notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, + replication_streams={}, ) hs.datastores = datastores diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 0e04b2cf92..43425c969a 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] def run_background_update(self): @@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") self.requester = Requester(self.user, None, False, None, None) - info = self.get_success(self.room_creator.create_room(self.requester, {})) + info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() homeserver.config.user_consent_version = self.CONSENT_VERSION diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index a7b7fd36d3..a7b85004e5 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): events = [(3, 2), (6, 2), (4, 6)] for event_count, extrems in events: - info = self.get_success(room_creator.create_room(requester, {})) + info, _ = self.get_success(room_creator.create_room(requester, {})) room_id = info["room_id"] last_event = None diff --git a/tests/test_federation.py b/tests/test_federation.py index 13ff14863e..c5099dd039 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase): user_id = UserID("us", "test") our_user = Requester(user_id, None, False, None, None) room_creator = self.homeserver.get_room_creation_handler() - room = ensureDeferred( + room_deferred = ensureDeferred( room_creator.create_room( our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False ) ) self.reactor.advance(0.1) - self.room_id = self.successResultOf(room)["room_id"] + self.room_id = self.successResultOf(room_deferred)[0]["room_id"] self.store = self.homeserver.get_datastore() -- cgit 1.5.1 From e5c67d04dbe5ed45d659e826a5dfcd5044a4e374 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 16:11:35 +0100 Subject: Add option to move event persistence off master (#7517) --- changelog.d/7517.feature | 1 + scripts/synapse_port_db | 3 + synapse/app/generic_worker.py | 53 +++++++++- synapse/config/logger.py | 1 + synapse/config/workers.py | 30 +++++- synapse/handlers/federation.py | 16 +-- synapse/handlers/message.py | 12 ++- synapse/handlers/presence.py | 6 ++ synapse/handlers/room.py | 7 ++ synapse/handlers/room_member.py | 39 ++++++-- synapse/replication/http/__init__.py | 4 +- synapse/replication/http/_base.py | 3 + synapse/replication/http/membership.py | 40 +++++++- synapse/replication/http/presence.py | 116 ++++++++++++++++++++++ synapse/replication/tcp/handler.py | 10 ++ synapse/rest/admin/rooms.py | 11 +- synapse/storage/data_stores/__init__.py | 6 +- synapse/storage/data_stores/main/devices.py | 30 ++++-- synapse/storage/data_stores/main/events.py | 34 ++++++- synapse/storage/data_stores/main/events_worker.py | 2 +- synapse/storage/data_stores/main/roommember.py | 25 ----- synapse/storage/persist_events.py | 6 ++ 22 files changed, 382 insertions(+), 73 deletions(-) create mode 100644 changelog.d/7517.feature create mode 100644 synapse/replication/http/presence.py (limited to 'synapse/storage') diff --git a/changelog.d/7517.feature b/changelog.d/7517.feature new file mode 100644 index 0000000000..6062f0061c --- /dev/null +++ b/changelog.d/7517.feature @@ -0,0 +1 @@ +Add option to move event persistence off master. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index acd9ac4b75..9a0fbc61d8 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -196,6 +196,9 @@ class MockHomeserver: def get_reactor(self): return reactor + def get_instance_name(self): + return "master" + class Porter(object): def __init__(self, **kwargs): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index a37520000a..2906b93f6a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -39,7 +39,11 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.logger import setup_logging from synapse.federation import send_queue from synapse.federation.transport.server import TransportLayerServer -from synapse.handlers.presence import BasePresenceHandler, get_interested_parties +from synapse.handlers.presence import ( + BasePresenceHandler, + PresenceState, + get_interested_parties, +) from synapse.http.server import JsonResource, OptionsResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseSite @@ -47,6 +51,10 @@ from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource +from synapse.replication.http.presence import ( + ReplicationBumpPresenceActiveTime, + ReplicationPresenceSetState, +) from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore @@ -247,6 +255,9 @@ class GenericWorkerPresence(BasePresenceHandler): # but we haven't notified the master of that yet self.users_going_offline = {} + self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) + self._set_state_client = ReplicationPresenceSetState.make_client(hs) + self._send_stop_syncing_loop = self.clock.looping_call( self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) @@ -304,10 +315,6 @@ class GenericWorkerPresence(BasePresenceHandler): self.users_going_offline.pop(user_id, None) self.send_user_sync(user_id, False, last_sync_ms) - def set_state(self, user, state, ignore_status_msg=False): - # TODO Hows this supposed to work? - return defer.succeed(None) - async def user_syncing( self, user_id: str, affect_presence: bool ) -> ContextManager[None]: @@ -386,6 +393,42 @@ class GenericWorkerPresence(BasePresenceHandler): if count > 0 ] + async def set_state(self, target_user, state, ignore_status_msg=False): + """Set the presence state of the user. + """ + presence = state["presence"] + + valid_presence = ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ) + if presence not in valid_presence: + raise SynapseError(400, "Invalid presence state") + + user_id = target_user.to_string() + + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + await self._set_state_client( + user_id=user_id, state=state, ignore_status_msg=ignore_status_msg + ) + + async def bump_presence_active_time(self, user): + """We've seen the user do something that indicates they're interacting + with the app. + """ + # If presence is disabled, no-op + if not self.hs.config.use_presence: + return + + # Proxy request to master + user_id = user.to_string() + await self._bump_active_client(user_id=user_id) + class GenericWorkerTyping(object): def __init__(self, hs): diff --git a/synapse/config/logger.py b/synapse/config/logger.py index a25c70e928..49f6c32beb 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -257,5 +257,6 @@ def setup_logging( logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) + logging.info("Instance name: %s", hs.get_instance_name()) return logger diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c80c338584..ed06b91a54 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -15,7 +15,7 @@ import attr -from ._base import Config +from ._base import Config, ConfigError @attr.s @@ -27,6 +27,17 @@ class InstanceLocationConfig: port = attr.ib(type=int) +@attr.s +class WriterLocations: + """Specifies the instances that write various streams. + + Attributes: + events: The instance that writes to the event and backfill streams. + """ + + events = attr.ib(default="master", type=str) + + class WorkerConfig(Config): """The workers are processes run separately to the main synapse process. They have their own pid_file and listener configuration. They use the @@ -83,11 +94,26 @@ class WorkerConfig(Config): bind_addresses.append("") # A map from instance name to host/port of their HTTP replication endpoint. - instance_map = config.get("instance_map", {}) or {} + instance_map = config.get("instance_map") or {} self.instance_map = { name: InstanceLocationConfig(**c) for name, c in instance_map.items() } + # Map from type of streams to source, c.f. WriterLocations. + writers = config.get("stream_writers") or {} + self.writers = WriterLocations(**writers) + + # Check that the configured writer for events also appears in + # `instance_map`. + if ( + self.writers.events != "master" + and self.writers.events not in self.instance_map + ): + raise ConfigError( + "Instance %r is configured to write events but does not appear in `instance_map` config." + % (self.writers.events,) + ) + def read_arguments(self, args): # We support a bunch of command line arguments that override options in # the config. A lot of these options have a worker_* prefix when running diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e354c803db..75ec90d267 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -126,11 +126,10 @@ class FederationHandler(BaseHandler): self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() + self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() - self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( - hs - ) + self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client( hs ) @@ -1243,6 +1242,10 @@ class FederationHandler(BaseHandler): content: The event content to use for the join event. """ + # TODO: We should be able to call this on workers, but the upgrading of + # room stuff after join currently doesn't work on workers. + assert self.config.worker.worker_app is None + logger.debug("Joining %s to %s", joinee, room_id) origin, event, room_version_obj = await self._make_and_verify_event( @@ -1314,7 +1317,7 @@ class FederationHandler(BaseHandler): # # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - "master", "events", max_stream_id + self.config.worker.writers.events, "events", max_stream_id ) # Check whether this room is the result of an upgrade of a room we already know @@ -2854,8 +2857,9 @@ class FederationHandler(BaseHandler): backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker_app: - result = await self._send_events_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self._send_events( + instance_name=self.config.worker.writers.events, store=self.store, event_and_contexts=event_and_contexts, backfilled=backfilled, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index f445e2aa2a..ea25f0515a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -366,10 +366,11 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types - self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) + self.send_event = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users self.base_handler = BaseHandler(hs) @@ -835,8 +836,9 @@ class EventCreationHandler(object): success = False try: # If we're a worker we need to hit out to the master. - if self.config.worker_app: - result = await self.send_event_to_master( + if self.config.worker.writers.events != self._instance_name: + result = await self.send_event( + instance_name=self.config.worker.writers.events, event_id=event.event_id, store=self.store, requester=requester, @@ -902,9 +904,9 @@ class EventCreationHandler(object): """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. - This should only be run on master. + This should only be run on the instance in charge of persisting events. """ - assert not self.config.worker_app + assert self.config.worker.writers.events == self._instance_name if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 9ea11c0754..3594f3b00f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -193,6 +193,12 @@ class BasePresenceHandler(abc.ABC): ) -> None: """Set the presence state of the user. """ + @abc.abstractmethod + async def bump_presence_active_time(self, user: UserID): + """We've seen the user do something that indicates they're interacting + with the app. + """ + class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "synapse.server.HomeServer"): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2698a129ca..61db3ccc43 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -89,6 +89,8 @@ class RoomCreationHandler(BaseHandler): self.room_member_handler = hs.get_room_member_handler() self.config = hs.config + self._replication = hs.get_replication_data_handler() + # linearizer to stop two upgrades happening at once self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") @@ -752,6 +754,11 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() + # Always wait for room creation to progate before returning + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", last_stream_id + ) + return result, last_stream_id async def _send_events_for_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 691b6705b2..0f7af982f0 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,6 +26,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.replication.http.membership import ( + ReplicationLocallyRejectInviteRestServlet, +) from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -44,11 +47,6 @@ class RoomMemberHandler(object): __metaclass__ = abc.ABCMeta def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() @@ -71,6 +69,17 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles + self._event_stream_writer_instance = hs.config.worker.writers.events + self._is_on_event_persistence_instance = ( + self._event_stream_writer_instance == hs.get_instance_name() + ) + if self._is_on_event_persistence_instance: + self.persist_event_storage = hs.get_storage().persistence + else: + self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client( + hs + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -121,6 +130,22 @@ class RoomMemberHandler(object): """ raise NotImplementedError() + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + if self._is_on_event_persistence_instance: + return await self.persist_event_storage.locally_reject_invite( + user_id, room_id + ) + else: + result = await self._locally_reject_client( + instance_name=self._event_stream_writer_instance, + user_id=user_id, + room_id=room_id, + ) + return result["stream_id"] + @abc.abstractmethod async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the @@ -1015,9 +1040,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.store.locally_reject_invite( - target.to_string(), room_id - ) + stream_id = await self.locally_reject_invite(target.to_string(), room_id) return None, stream_id async def _user_joined_room(self, target: UserID, room_id: str) -> None: diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a909744e93..19b69e0e11 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -19,6 +19,7 @@ from synapse.replication.http import ( federation, login, membership, + presence, register, send_event, streams, @@ -35,10 +36,11 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) federation.register_servlets(hs, self) + presence.register_servlets(hs, self) + membership.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: - membership.register_servlets(hs, self) login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index c3136a4eb9..793cef6c26 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -142,6 +142,7 @@ class ReplicationEndpoint(object): """ clock = hs.get_clock() client = hs.get_simple_http_client() + local_instance_name = hs.get_instance_name() master_host = hs.config.worker_replication_host master_port = hs.config.worker_replication_http_port @@ -151,6 +152,8 @@ class ReplicationEndpoint(object): @trace(opname="outgoing_replication_request") @defer.inlineCallbacks def send_request(instance_name="master", **kwargs): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") if instance_name == "master": host = master_host port = master_port diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 050fd34562..a7174c4a8f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -14,12 +14,16 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -106,6 +110,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() + self.member_handler = hs.get_room_member_handler() @staticmethod def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): @@ -149,12 +154,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - stream_id = await self.store.locally_reject_invite(user_id, room_id) + stream_id = await self.member_handler.locally_reject_invite( + user_id, room_id + ) event_id = None return 200, {"event_id": event_id, "stream_id": stream_id} +class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint): + """Rejects the invite for the user and room locally. + + Request format: + + POST /_synapse/replication/locally_reject_invite/:room_id/:user_id + + {} + """ + + NAME = "locally_reject_invite" + PATH_ARGS = ("room_id", "user_id") + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.member_handler = hs.get_room_member_handler() + + @staticmethod + def _serialize_payload(room_id, user_id): + return {} + + async def _handle_request(self, request, room_id, user_id): + logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id) + + stream_id = await self.member_handler.locally_reject_invite(user_id, room_id) + + return 200, {"stream_id": stream_id} + + class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): """Notifies that a user has joined or left the room @@ -208,3 +245,4 @@ def register_servlets(hs, http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) + ReplicationLocallyRejectInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py new file mode 100644 index 0000000000..ea1b33331b --- /dev/null +++ b/synapse/replication/http/presence.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): + """We've seen the user do something that indicates they're interacting + with the app. + + The POST looks like: + + POST /_synapse/replication/bump_presence_active_time/ + + 200 OK + + {} + """ + + NAME = "bump_presence_active_time" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + await self._presence_handler.bump_presence_active_time( + UserID.from_string(user_id) + ) + + return ( + 200, + {}, + ) + + +class ReplicationPresenceSetState(ReplicationEndpoint): + """Set the presence state for a user. + + The POST looks like: + + POST /_synapse/replication/presence_set_state/ + + { + "state": { ... }, + "ignore_status_msg": false, + } + + 200 OK + + {} + """ + + NAME = "presence_set_state" + PATH_ARGS = ("user_id",) + METHOD = "POST" + CACHE = False + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._presence_handler = hs.get_presence_handler() + + @staticmethod + def _serialize_payload(user_id, state, ignore_status_msg=False): + return { + "state": state, + "ignore_status_msg": ignore_status_msg, + } + + async def _handle_request(self, request, user_id): + content = parse_json_object_from_request(request) + + await self._presence_handler.set_state( + UserID.from_string(user_id), content["state"], content["ignore_status_msg"] + ) + + return ( + 200, + {}, + ) + + +def register_servlets(hs, http_server): + ReplicationBumpPresenceActiveTime(hs).register(http_server) + ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index acfa66a7a8..03300e5336 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -38,7 +38,9 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + BackfillStream, CachesStream, + EventsStream, FederationStream, Stream, ) @@ -87,6 +89,14 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) continue + if isinstance(stream, (EventsStream, BackfillStream)): + # Only add EventStream and BackfillStream as a source on the + # instance in charge of event persistence. + if hs.config.worker.writers.events == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 0a13e1ed34..8173baef8f 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -100,7 +100,9 @@ class ShutdownRoomRestServlet(RestServlet): # we try and auto join below. # # TODO: Currently the events stream is written to from master - await self._replication.wait_for_stream_position("master", "events", stream_id) + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) users = await self.state.get_current_users_in_room(room_id) kicked_users = [] @@ -113,7 +115,7 @@ class ShutdownRoomRestServlet(RestServlet): try: target_requester = create_requester(user_id) - await self.room_member_handler.update_membership( + _, stream_id = await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -123,6 +125,11 @@ class ShutdownRoomRestServlet(RestServlet): require_consent=False, ) + # Wait for leave to come in over replication before trying to forget. + await self._replication.wait_for_stream_position( + self.hs.config.worker.writers.events, "events", stream_id + ) + await self.room_member_handler.forget(target_requester.user, room_id) await self.room_member_handler.update_membership( diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 791961b296..599ee470d4 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -66,9 +66,9 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) - # If we're on a process that can persist events (currently - # master), also instantiate a `PersistEventsStore` - if hs.config.worker.worker_app is None: + # If we're on a process that can persist events also + # instantiate a `PersistEventsStore` + if hs.config.worker.writers.events == hs.get_instance_name(): self.persist_events = PersistEventsStore( hs, database, self.main ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 0e8378714a..417ac8dc7c 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -689,6 +689,25 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + + def _mark_remote_user_device_list_as_unsubscribed_txn(txn): + self.db.simple_delete_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_device_list_last_stream_id_for_remote, (user_id,) + ) + + return self.db.runInteraction( + "mark_remote_user_device_list_as_unsubscribed", + _mark_remote_user_device_list_as_unsubscribed_txn, + ) + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): @@ -969,17 +988,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. - """ - yield self.db.simple_delete( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - desc="mark_remote_user_device_list_as_unsubscribed", - ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry( self, user_id, device_id, content, stream_id ): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index a97f8b3934..a6572571b4 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -138,10 +138,10 @@ class PersistEventsStore: self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator - # This should only exist on master for now + # This should only exist on instances that are configured to write assert ( - hs.config.worker.worker_app is None - ), "Can only instantiate PersistEventsStore on master" + hs.config.worker.writers.events == hs.get_instance_name() + ), "Can only instantiate EventsStore on master" @_retry_on_integrity_error @defer.inlineCallbacks @@ -1590,3 +1590,31 @@ class PersistEventsStore: if not ev.internal_metadata.is_outlier() ], ) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, (stream_ordering, True, room_id, user_id)) + + # We also clear this entry from `local_current_membership`. + # Ideally we'd point to a leave event, but we don't have one, so + # nevermind. + self.db.simple_delete_txn( + txn, + table="local_current_membership", + keyvalues={"room_id": room_id, "user_id": user_id}, + ) + + with self._stream_id_gen.get_next() as stream_ordering: + await self.db.runInteraction("locally_reject_invite", f, stream_ordering) + + return stream_ordering diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index b880a71782..213d69100a 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -76,7 +76,7 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker_app is None: + if hs.config.worker.writers.events == hs.get_instance_name(): # We are the process in charge of generating stream ids for events, # so instantiate ID generators based on the database self._stream_id_gen = StreamIdGenerator( diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 7c5ca81ae0..137ebac833 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1046,31 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - def f(txn, stream_ordering): - txn.execute(sql, (stream_ordering, True, room_id, user_id)) - - # We also clear this entry from `local_current_membership`. - # Ideally we'd point to a leave event, but we don't have one, so - # nevermind. - self.db.simple_delete_txn( - txn, - table="local_current_membership", - keyvalues={"room_id": room_id, "user_id": user_id}, - ) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) - - return stream_ordering - def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 12e1ffb9a2..f159400a87 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -786,3 +786,9 @@ class EventsPersistenceStorage(object): for user_id in left_users: await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id) + + async def locally_reject_invite(self, user_id: str, room_id: str) -> int: + """Mark the invite has having been rejected even though we failed to + create a leave event for it. + """ + return await self.persist_events_store.locally_reject_invite(user_id, room_id) -- cgit 1.5.1 From f4269694ce51a04761e43fc5897d6f8fe0ea18cc Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 22 May 2020 21:47:07 +0100 Subject: Optimise some references to hs.config (#7546) These are surprisingly expensive, and we only really need to do them at startup. --- changelog.d/7546.misc | 1 + synapse/handlers/message.py | 8 +- .../resource_limits_server_notices.py | 15 ++- .../data_stores/main/monthly_active_users.py | 28 ++++-- .../test_resource_limits_server_notices.py | 60 ++++++----- tests/storage/test_client_ips.py | 13 +-- tests/storage/test_monthly_active_users.py | 110 +++++++++++---------- tests/test_mau.py | 63 ++++++------ 8 files changed, 162 insertions(+), 136 deletions(-) create mode 100644 changelog.d/7546.misc (limited to 'synapse/storage') diff --git a/changelog.d/7546.misc b/changelog.d/7546.misc new file mode 100644 index 0000000000..0a9704789b --- /dev/null +++ b/changelog.d/7546.misc @@ -0,0 +1 @@ +Optimise some references to `hs.config`. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ea25f0515a..681f92cafd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -366,7 +366,9 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._instance_name = hs.get_instance_name() + self._is_event_writer = ( + self.config.worker.writers.events == hs.get_instance_name() + ) self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -836,7 +838,7 @@ class EventCreationHandler(object): success = False try: # If we're a worker we need to hit out to the master. - if self.config.worker.writers.events != self._instance_name: + if not self._is_event_writer: result = await self.send_event( instance_name=self.config.worker.writers.events, event_id=event.event_id, @@ -906,7 +908,7 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self.config.worker.writers.events == self._instance_name + assert self._is_event_writer if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index d97166351e..73f2cedb5c 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -48,6 +48,12 @@ class ResourceLimitsServerNotices(object): self._notifier = hs.get_notifier() + self._enabled = ( + hs.config.limit_usage_by_mau + and self._server_notices_manager.is_enabled() + and not hs.config.hs_disabled + ) + async def maybe_send_server_notice_to_user(self, user_id): """Check if we need to send a notice to this user, this will be true in two cases. @@ -61,14 +67,7 @@ class ResourceLimitsServerNotices(object): Returns: Deferred """ - if self._config.hs_disabled is True: - return - - if self._config.limit_usage_by_mau is False: - return - - if not self._server_notices_manager.is_enabled(): - # Don't try and send server notices unless they've been enabled + if not self._enabled: return timestamp = await self._store.user_last_seen_monthly_active(user_id) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index 925bc5691b..a42c52f323 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -122,6 +122,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: Database, db_conn, hs): super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) + self._limit_usage_by_mau = hs.config.limit_usage_by_mau + self._mau_stats_only = hs.config.mau_stats_only + self._max_mau_value = hs.config.max_mau_value + # Do not add more reserved users than the total allowable number # cur = LoggingTransaction( self.db.new_transaction( @@ -130,7 +134,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): [], [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value], + hs.config.mau_limits_reserved_threepids[: self._max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): @@ -142,6 +146,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): threepids (list[dict]): List of threepid dicts to reserve """ + # XXX what is this function trying to achieve? It upserts into + # monthly_active_users for each *registered* reserved mau user, but why? + # + # - shouldn't there already be an entry for each reserved user (at least + # if they have been active recently)? + # + # - if it's important that the timestamp is kept up to date, why do we only + # run this at startup? + for tp in threepids: user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) @@ -191,8 +204,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn.execute(sql, query_args) - max_mau_value = self.hs.config.max_mau_value - if self.hs.config.limit_usage_by_mau: + if self._limit_usage_by_mau: # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. # Note it is not possible to write this query using OFFSET due to @@ -210,13 +222,13 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): LIMIT ? ) """ - txn.execute(sql, (max_mau_value,)) + txn.execute(sql, ((self._max_mau_value),)) # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres # when len(reserved_users) == 0. Works fine on sqlite. else: # Must be >= 0 for postgres num_of_non_reserved_users_to_remove = max( - max_mau_value - len(reserved_users), 0 + self._max_mau_value - len(reserved_users), 0 ) # It is important to filter reserved users twice to guard @@ -335,7 +347,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): Args: user_id(str): the user_id to query """ - if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only: + if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group is_guest = yield self.is_guest(user_id) if is_guest: @@ -356,11 +368,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): # In the case where mau_stats_only is True and limit_usage_by_mau is # False, there is no point in checking get_monthly_active_count - it # adds no value and will break the logic if max_mau_value is exceeded. - if not self.hs.config.limit_usage_by_mau: + if not self._limit_usage_by_mau: yield self.upsert_monthly_active_user(user_id) else: count = yield self.get_monthly_active_count() - if count < self.hs.config.max_mau_value: + if count < self._max_mau_value: yield self.upsert_monthly_active_user(user_id) elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY: yield self.upsert_monthly_active_user(user_id) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 406f29a7c0..99908edba3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,20 +27,33 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - hs_config = self.default_config() - hs_config["server_notices"] = { - "system_mxid_localpart": "server", - "system_mxid_display_name": "test display name", - "system_mxid_avatar_url": None, - "room_name": "Server Notices", - } + def default_config(self): + config = default_config("test") + + config.update( + { + "admin_contact": "mailto:user@test.com", + "limit_usage_by_mau": True, + "server_notices": { + "system_mxid_localpart": "server", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + }, + } + ) + + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) - hs = self.setup_test_homeserver(config=hs_config) - return hs + return config def prepare(self, reactor, clock, hs): self.server_notices_sender = self.hs.get_server_notices_sender() @@ -60,7 +73,6 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self._send_notice = self._rlsn._server_notices_manager.send_notice - self.hs.config.limit_usage_by_mau = True self.user_id = "@user_id:test" self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( @@ -68,21 +80,17 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) - self.hs.config.admin_contact = "mailto:user@test.com" - - def test_maybe_send_server_notice_to_user_flag_off(self): - """Tests cases where the flags indicate nothing to do""" - # test hs disabled case - self.hs.config.hs_disabled = True + @override_config({"hs_disabled": True}) + def test_maybe_send_server_notice_disabled_hs(self): + """If the HS is disabled, we should not send notices""" self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) - self._send_notice.assert_not_called() - # Test when mau limiting disabled - self.hs.config.hs_disabled = False - self.hs.config.limit_usage_by_mau = False - self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) + @override_config({"limit_usage_by_mau": False}) + def test_maybe_send_server_notice_to_user_flag_off(self): + """If mau limiting is disabled, we should not send notices""" + self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): @@ -153,13 +161,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() + @override_config({"mau_limit_alerting": False}) def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self): """ Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self.hs.config.mau_limit_alerting = False - self._rlsn._auth.check_auth_blocking = Mock( return_value=defer.succeed(None), side_effect=ResourceLimitError( @@ -170,12 +177,11 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self.assertEqual(self._send_notice.call_count, 0) + @override_config({"mau_limit_alerting": False}) def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self.hs.config.mau_limit_alerting = False - self._rlsn._auth.check_auth_blocking = Mock( return_value=defer.succeed(None), side_effect=ResourceLimitError( @@ -187,12 +193,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): # Would be better to check contents, but 2 calls == set blocking event self.assertEqual(self._send_notice.call_count, 2) + @override_config({"mau_limit_alerting": False}) def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self): """ When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self.hs.config.mau_limit_alerting = False self._rlsn._auth.check_auth_blocking = Mock( return_value=defer.succeed(None), side_effect=ResourceLimitError( diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index bf674dd184..3b483bc7f0 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest from synapse.rest.client.v1 import login from tests import unittest +from tests.unittest import override_config class ClientIpStoreTestCase(unittest.HomeserverTestCase): @@ -137,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"limit_usage_by_mau": False, "max_mau_value": 50}) def test_disabled_monthly_active_user(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success( self.store.insert_client_ip( @@ -149,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_full(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 lots_of_users = 100 user_id = "@user:server" @@ -166,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_adding_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertFalse(active) @@ -184,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): active = self.get_success(self.store.user_last_seen_monthly_active(user_id)) self.assertTrue(active) + @override_config({"limit_usage_by_mau": True, "max_mau_value": 50}) def test_updating_monthly_active_user_when_space(self): - self.hs.config.limit_usage_by_mau = True - self.hs.config.max_mau_value = 50 user_id = "@user:server" self.get_success(self.store.register_user(user_id=user_id, password_hash=None)) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index bc53bf0951..447fcb3a1c 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -19,94 +19,106 @@ from twisted.internet import defer from synapse.api.constants import UserTypes from tests import unittest +from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 +def gen_3pids(count): + """Generate `count` threepids as a list.""" + return [ + {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count) + ] + + class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") + + config.update({"limit_usage_by_mau": True, "max_mau_value": 50}) - hs = self.setup_test_homeserver() - self.store = hs.get_datastore() - hs.config.limit_usage_by_mau = True - hs.config.max_mau_value = 50 + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) + return config + + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() # Advance the clock a bit reactor.advance(FORTY_DAYS) - return hs - + @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - self.hs.config.max_mau_value = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + + # register three users, of which two have reserved 3pids, and a third + # which is a support user. user1 = "@user1:server" - user1_email = "user1@matrix.org" + user1_email = threepids[0]["address"] user2 = "@user2:server" - user2_email = "user2@matrix.org" + user2_email = threepids[1]["address"] user3 = "@user3:server" - user3_email = "user3@matrix.org" - threepids = [ - {"medium": "email", "address": user1_email}, - {"medium": "email", "address": user2_email}, - {"medium": "email", "address": user3_email}, - ] - self.hs.config.mau_limits_reserved_threepids = threepids - # -1 because user3 is a support user and does not count - user_num = len(threepids) - 1 - - self.store.register_user(user_id=user1, password_hash=None) - self.store.register_user(user_id=user2, password_hash=None) - self.store.register_user( - user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT - ) + self.store.register_user(user_id=user1) + self.store.register_user(user_id=user2) + self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT) self.pump() now = int(self.hs.get_clock().time_msec()) self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) + # XXX why are we doing this here? this function is only run at startup + # so it is odd to re-run it here. self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() - active_count = self.store.get_monthly_active_count() + # the number of users we expect will be counted against the mau limit + # -1 because user3 is a support user and does not count + user_num = len(threepids) - 1 - # Test total counts, ensure user3 (support user) is not counted - self.assertEquals(self.get_success(active_count), user_num) + # Check the number of active users. Ensure user3 (support user) is not counted + active_count = self.get_success(self.store.get_monthly_active_count()) + self.assertEquals(active_count, user_num) - # Test user is marked as active + # Test each of the registered users is marked as active timestamp = self.store.user_last_seen_monthly_active(user1) self.assertTrue(self.get_success(timestamp)) timestamp = self.store.user_last_seen_monthly_active(user2) self.assertTrue(self.get_success(timestamp)) - # Test that users are never removed from the db. + # Test that users with reserved 3pids are not removed from the MAU table + # XXX some of this is redundant. poking things into the config shouldn't + # work, and in any case it's not obvious what we expect to happen when + # we advance the reactor. self.hs.config.max_mau_value = 0 - self.reactor.advance(FORTY_DAYS) self.hs.config.max_mau_value = 5 - self.store.reap_monthly_active_users() self.pump() active_count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(active_count), user_num) - # Test that regular users are removed from the db + # Add some more users and check they are counted as active ru_count = 2 self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru2:server") self.pump() - active_count = self.store.get_monthly_active_count() self.assertEqual(self.get_success(active_count), user_num + ru_count) - self.hs.config.max_mau_value = user_num + + # now run the reaper and check that the number of active users is reduced + # to max_mau_value self.store.reap_monthly_active_users() self.pump() active_count = self.store.get_monthly_active_count() - self.assertEquals(self.get_success(active_count), user_num) + self.assertEquals(self.get_success(active_count), 3) def test_can_insert_and_count_mau(self): count = self.store.get_monthly_active_count() @@ -136,8 +148,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): result = self.store.user_last_seen_monthly_active(user_id3) self.assertNotEqual(self.get_success(result), 0) + @override_config({"max_mau_value": 5}) def test_reap_monthly_active_users(self): - self.hs.config.max_mau_value = 5 initial_users = 10 for i in range(initial_users): self.store.upsert_monthly_active_user("@user%d:server" % i) @@ -158,19 +170,19 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEquals(self.get_success(count), 0) + # Note that below says mau_limit (no s), this is the name of the config + # value, although it gets stored on the config object as mau_limits. + @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)}) def test_reap_monthly_active_users_reserved_users(self): """ Tests that reaping correctly handles reaping where reserved users are present""" - - self.hs.config.max_mau_value = 5 - initial_users = 5 + threepids = self.hs.config.mau_limits_reserved_threepids + initial_users = len(threepids) reserved_user_number = initial_users - 1 - threepids = [] for i in range(initial_users): user = "@user%d:server" % i - email = "user%d@example.com" % i + email = "user%d@matrix.org" % i self.get_success(self.store.upsert_monthly_active_user(user)) - threepids.append({"medium": "email", "address": email}) # Need to ensure that the most recent entries in the # monthly_active_users table are reserved now = int(self.hs.get_clock().time_msec()) @@ -182,7 +194,6 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user, "email", email, now, now) ) - self.hs.config.mau_limits_reserved_threepids = threepids self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) @@ -279,11 +290,11 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.pump() self.assertEqual(self.get_success(count), 0) + # Note that the max_mau_value setting should not matter. + @override_config( + {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1} + ) def test_track_monthly_users_without_cap(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - self.hs.config.max_mau_value = 1 # should not matter - count = self.store.get_monthly_active_count() self.assertEqual(0, self.get_success(count)) @@ -294,9 +305,8 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): count = self.store.get_monthly_active_count() self.assertEqual(2, self.get_success(count)) + @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self): - self.hs.config.limit_usage_by_mau = False - self.hs.config.mau_stats_only = False self.store.upsert_monthly_active_user = Mock() self.store.populate_monthly_active_users("@user:sever") diff --git a/tests/test_mau.py b/tests/test_mau.py index eb159e3ba5..8a97f0998d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,47 +17,44 @@ import json -from mock import Mock - -from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.rest.client.v2_alpha import register, sync from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config class TestMauLimit(unittest.HomeserverTestCase): servlets = [register.register_servlets, sync.register_servlets] - def make_homeserver(self, reactor, clock): + def default_config(self): + config = default_config("test") - self.hs = self.setup_test_homeserver( - "red", http_client=None, federation_client=Mock() + config.update( + { + "registrations_require_3pid": [], + "limit_usage_by_mau": True, + "max_mau_value": 2, + "mau_trial_days": 0, + "server_notices": { + "system_mxid_localpart": "server", + "room_name": "Test Server Notice Room", + }, + } ) - self.store = self.hs.get_datastore() - - self.hs.config.registrations_require_3pid = [] - self.hs.config.enable_registration_captcha = False - self.hs.config.recaptcha_public_key = [] - - self.hs.config.limit_usage_by_mau = True - self.hs.config.hs_disabled = False - self.hs.config.max_mau_value = 2 - self.hs.config.server_notices_mxid = "@server:red" - self.hs.config.server_notices_mxid_display_name = None - self.hs.config.server_notices_mxid_avatar_url = None - self.hs.config.server_notices_room_name = "Test Server Notice Room" - self.hs.config.mau_trial_days = 0 + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) - # AuthBlocking reads config options during hs creation. Recreate the - # hs' copy of AuthBlocking after we've updated config values above - self.auth_blocking = AuthBlocking(self.hs) - self.hs.get_auth()._auth_blocking = self.auth_blocking + return config - return self.hs + def prepare(self, reactor, clock, homeserver): + self.store = homeserver.get_datastore() def test_simple_deny_mau(self): # Create and sync so that the MAU counts get updated @@ -66,6 +63,9 @@ class TestMauLimit(unittest.HomeserverTestCase): token2 = self.create_user("kermit2") self.do_sync_for_user(token2) + # check we're testing what we think we are: there should be two active users + self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2) + # We've created and activated two users, we shouldn't be able to # register new users with self.assertRaises(SynapseError) as cm: @@ -93,9 +93,8 @@ class TestMauLimit(unittest.HomeserverTestCase): token3 = self.create_user("kermit3") self.do_sync_for_user(token3) + @override_config({"mau_trial_days": 1}) def test_trial_delay(self): - self.hs.config.mau_trial_days = 1 - # We should be able to register more than the limit initially token1 = self.create_user("kermit1") self.do_sync_for_user(token1) @@ -127,8 +126,8 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + @override_config({"mau_trial_days": 1}) def test_trial_users_cant_come_back(self): - self.auth_blocking._mau_trial_days = 1 self.hs.config.mau_trial_days = 1 # We should be able to register more than the limit initially @@ -176,11 +175,11 @@ class TestMauLimit(unittest.HomeserverTestCase): self.assertEqual(e.code, 403) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + @override_config( + # max_mau_value should not matter + {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True} + ) def test_tracked_but_not_limited(self): - self.auth_blocking._max_mau_value = 1 # should not matter - self.auth_blocking._limit_usage_by_mau = False - self.hs.config.mau_stats_only = True - # Simply being able to create 2 users indicates that the # limit was not reached. token1 = self.create_user("kermit1") -- cgit 1.5.1 From d14c4d6b6d02e8c55a6647ce0f0b3b44356d6173 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Sat, 23 May 2020 01:20:10 +0100 Subject: Simplify reap_monthly_active_users (#7558) we can use `make_in_list_sql_clause` rather than doing our own half-baked equivalent, which has the benefit of working just fine with empty lists. (This has quite a lot of tests, so I think it's pretty safe) --- changelog.d/7558.misc | 1 + .../data_stores/main/monthly_active_users.py | 100 +++++++++------------ 2 files changed, 42 insertions(+), 59 deletions(-) create mode 100644 changelog.d/7558.misc (limited to 'synapse/storage') diff --git a/changelog.d/7558.misc b/changelog.d/7558.misc new file mode 100644 index 0000000000..f8f4111136 --- /dev/null +++ b/changelog.d/7558.misc @@ -0,0 +1 @@ +Simplify `reap_monthly_active_users`. diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index a42c52f323..1310d39069 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,7 +17,7 @@ import logging from twisted.internet import defer from synapse.storage._base import SQLBaseStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_in_list_sql_clause from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -187,75 +187,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - query_args = [thirty_days_ago] - base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?" - - # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres - # when len(reserved_users) == 0. Works fine on sqlite. - if len(reserved_users) > 0: - # questionmarks is a hack to overcome sqlite not supporting - # tuples in 'WHERE IN %s' - question_marks = ",".join("?" * len(reserved_users)) - - query_args.extend(reserved_users) - sql = base_sql + " AND user_id NOT IN ({})".format(question_marks) - else: - sql = base_sql - txn.execute(sql, query_args) + in_clause, in_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", reserved_users + ) + + txn.execute( + "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s" + % (in_clause,), + [thirty_days_ago] + in_clause_args, + ) if self._limit_usage_by_mau: # If MAU user count still exceeds the MAU threshold, then delete on # a least recently active basis. # Note it is not possible to write this query using OFFSET due to # incompatibilities in how sqlite and postgres support the feature. - # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present - # While Postgres does not require 'LIMIT', but also does not support + # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present, + # while Postgres does not require 'LIMIT', but also does not support # negative LIMIT values. So there is no way to write it that both can # support - if len(reserved_users) == 0: - sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - ORDER BY timestamp DESC - LIMIT ? - ) - """ - txn.execute(sql, ((self._max_mau_value),)) - # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres - # when len(reserved_users) == 0. Works fine on sqlite. - else: - # Must be >= 0 for postgres - num_of_non_reserved_users_to_remove = max( - self._max_mau_value - len(reserved_users), 0 - ) - # It is important to filter reserved users twice to guard - # against the case where the reserved user is present in the - # SELECT, meaning that a legitmate mau is deleted. - sql = """ - DELETE FROM monthly_active_users - WHERE user_id NOT IN ( - SELECT user_id FROM monthly_active_users - WHERE user_id NOT IN ({}) - ORDER BY timestamp DESC - LIMIT ? - ) - AND user_id NOT IN ({}) - """.format( - question_marks, question_marks + # Limit must be >= 0 for postgres + num_of_non_reserved_users_to_remove = max( + self._max_mau_value - len(reserved_users), 0 + ) + + # It is important to filter reserved users twice to guard + # against the case where the reserved user is present in the + # SELECT, meaning that a legitimate mau is deleted. + sql = """ + DELETE FROM monthly_active_users + WHERE user_id NOT IN ( + SELECT user_id FROM monthly_active_users + WHERE NOT %s + ORDER BY timestamp DESC + LIMIT ? ) - - query_args = [ - *reserved_users, - num_of_non_reserved_users_to_remove, - *reserved_users, - ] - - txn.execute(sql, query_args) - - # It seems poor to invalidate the whole cache, Postgres supports + AND NOT %s + """ % ( + in_clause, + in_clause, + ) + + query_args = ( + in_clause_args + + [num_of_non_reserved_users_to_remove] + + in_clause_args + ) + txn.execute(sql, query_args) + + # It seems poor to invalidate the whole cache. Postgres supports # 'Returning' which would allow me to invalidate only the # specific users, but sqlite has no way to do this and instead # I would need to SELECT and the DELETE which without locking -- cgit 1.5.1 From edd9a7214c467e96f5a694598b2fbcfae3ac2912 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 26 May 2020 11:43:17 +0100 Subject: Replace device_27_unique_idx bg update with a fg one (#7562) The bg update never managed to complete, because it kept being interrupted by transactions which want to take a lock. Just doing it in the foreground isn't that bad, and is a good deal simpler. --- UPGRADE.rst | 12 +++- changelog.d/7562.misc | 1 + synapse/storage/data_stores/main/devices.py | 34 ++------- ...vice_lists_outbound_last_success_unique_idx.sql | 28 -------- .../main/schema/delta/58/06dlols_unique_idx.py | 80 ++++++++++++++++++++++ synapse/storage/database.py | 1 - synapse/storage/prepare_database.py | 13 ++-- 7 files changed, 104 insertions(+), 65 deletions(-) create mode 100644 changelog.d/7562.misc delete mode 100644 synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql create mode 100644 synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py (limited to 'synapse/storage') diff --git a/UPGRADE.rst b/UPGRADE.rst index 41c47e964d..3b5627e852 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,9 +75,15 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb -Upgrading to v1.13.0 +Upgrading to v1.14.0 ==================== +This version includes a database update which is run as part of the upgrade, +and which may take a couple of minutes in the case of a large server. Synapse +will not respond to HTTP requests while this update is taking place. + +Upgrading to v1.13.0 +==================== Incorrect database migration in old synapse versions ---------------------------------------------------- @@ -136,12 +142,12 @@ back to v1.12.4 you need to: 2. Decrease the schema version in the database: .. code:: sql - + UPDATE schema_version SET version = 57; 3. Downgrade Synapse by following the instructions for your installation method in the "Rolling back to older versions" section above. - + Upgrading to v1.12.0 ==================== diff --git a/changelog.d/7562.misc b/changelog.d/7562.misc new file mode 100644 index 0000000000..3c25cd9917 --- /dev/null +++ b/changelog.d/7562.misc @@ -0,0 +1 @@ +Improve performance of `mark_as_sent_devices_by_remote`. diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 417ac8dc7c..fb9f798e29 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -55,10 +55,6 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" -BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = ( - "drop_device_lists_outbound_last_success_non_unique_idx" -) - class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): @@ -749,19 +745,13 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, ) - # create a unique index on device_lists_outbound_last_success - self.db.updates.register_background_index_update( + # a pair of background updates that were added during the 1.14 release cycle, + # but replaced with 58/06dlols_unique_idx.py + self.db.updates.register_noop_background_update( "device_lists_outbound_last_success_unique_idx", - index_name="device_lists_outbound_last_success_unique_idx", - table="device_lists_outbound_last_success", - columns=["destination", "user_id"], - unique=True, ) - - # once that completes, we can remove the old non-unique index. - self.db.updates.register_background_update_handler( - BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX, - self._drop_device_lists_outbound_last_success_non_unique_idx, + self.db.updates.register_noop_background_update( + "drop_device_lists_outbound_last_success_non_unique_idx", ) @defer.inlineCallbacks @@ -838,20 +828,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): return rows - async def _drop_device_lists_outbound_last_success_non_unique_idx( - self, progress, batch_size - ): - def f(txn): - txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx") - - await self.db.runInteraction( - "drop_device_lists_outbound_last_success_non_unique_idx", f, - ) - await self.db.updates._end_background_update( - BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX - ) - return 1 - class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql deleted file mode 100644 index d5e6deb878..0000000000 --- a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2020 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- register a background update which will create a unique index on --- device_lists_outbound_last_success -INSERT into background_updates (ordering, update_name, progress_json) - VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}'); - --- once that completes, we can drop the old index. -INSERT into background_updates (ordering, update_name, progress_json, depends_on) - VALUES ( - 5804, - 'drop_device_lists_outbound_last_success_non_unique_idx', - '{}', - 'device_lists_outbound_last_success_unique_idx' - ); diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py new file mode 100644 index 0000000000..d353f2bcb3 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py @@ -0,0 +1,80 @@ +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This migration rebuilds the device_lists_outbound_last_success table without duplicate +entries, and with a UNIQUE index. +""" + +import logging +from io import StringIO + +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.prepare_database import execute_statements_from_stream +from synapse.storage.types import Cursor + +logger = logging.getLogger(__name__) + + +def run_upgrade(*args, **kwargs): + pass + + +def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs): + # some instances might already have this index, in which case we can skip this + if isinstance(database_engine, PostgresEngine): + cur.execute( + """ + SELECT 1 FROM pg_class WHERE relkind = 'i' + AND relname = 'device_lists_outbound_last_success_unique_idx' + """ + ) + + if cur.rowcount: + logger.info( + "Unique index exists on device_lists_outbound_last_success: " + "skipping rebuild" + ) + return + + logger.info("Rebuilding device_lists_outbound_last_success with unique index") + execute_statements_from_stream(cur, StringIO(_rebuild_commands)) + + +# there might be duplicates, so the easiest way to achieve this is to create a new +# table with the right data, and renaming it into place + +_rebuild_commands = """ +DROP TABLE IF EXISTS device_lists_outbound_last_success_new; + +CREATE TABLE device_lists_outbound_last_success_new ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +-- this took about 30 seconds on matrix.org's 16 million rows. +INSERT INTO device_lists_outbound_last_success_new + SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success + GROUP BY destination, user_id; + +-- and this another 30 seconds. +CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx + ON device_lists_outbound_last_success_new (destination, user_id); + +DROP TABLE device_lists_outbound_last_success; + +ALTER TABLE device_lists_outbound_last_success_new + RENAME TO device_lists_outbound_last_success; +""" diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 9947dbce77..b112ff3df2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -78,7 +78,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", "event_search": "event_search_event_id_idx", - "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx", } diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 640f242584..9afc145340 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -19,10 +19,12 @@ import logging import os import re from collections import Counter +from typing import TextIO import attr from synapse.storage.engines.postgres import PostgresEngine +from synapse.storage.types import Cursor logger = logging.getLogger(__name__) @@ -479,8 +481,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ) logger.info("applying schema %s for %s", name, modname) - for statement in get_statements(stream): - cur.execute(statement) + execute_statements_from_stream(cur, stream) # Mark as done. cur.execute( @@ -538,8 +539,12 @@ def get_statements(f): def executescript(txn, schema_path): with open(schema_path, "r") as f: - for statement in get_statements(f): - txn.execute(statement) + execute_statements_from_stream(txn, f) + + +def execute_statements_from_stream(cur: Cursor, f: TextIO): + for statement in get_statements(f): + cur.execute(statement) def _get_or_create_schema_state(txn, database_engine): -- cgit 1.5.1