From 8123b2f90945e6bb3d65cf48730c2133b46d2b9a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 May 2020 17:17:45 +0100 Subject: Add MultiWriterIdGenerator. (#7281) This will be used to coordinate stream IDs across multiple writers. Functions as the equivalent of both `StreamIdGenerator` and `SlavedIdTracker`. --- synapse/storage/util/id_generators.py | 169 +++++++++++++++++++++++++++++++++- 1 file changed, 167 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 9d851beaa5..86d04ea9ac 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,6 +16,11 @@ import contextlib import threading from collections import deque +from typing import Dict, Set, Tuple + +from typing_extensions import Deque + +from synapse.storage.database import Database, LoggingTransaction class IdGenerator(object): @@ -87,7 +92,7 @@ class StreamIdGenerator(object): self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[int] def get_next(self): """ @@ -163,7 +168,7 @@ class ChainedIdGenerator(object): self.chained_generator = chained_generator self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) - self._unfinished_ids = deque() + self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] def get_next(self): """ @@ -198,3 +203,163 @@ class ChainedIdGenerator(object): return stream_id - 1, chained_id return self._current_max, self.chained_generator.get_current_token() + + +class MultiWriterIdGenerator: + """An ID generator that tracks a stream that can have multiple writers. + + Uses a Postgres sequence to coordinate ID assignment, but positions of other + writers will only get updated when `advance` is called (by replication). + + Note: Only works with Postgres. + + Args: + db_conn + db + instance_name: The name of this instance. + table: Database table associated with stream. + instance_column: Column that stores the row's writer's instance name + id_column: Column that stores the stream ID. + sequence_name: The name of the postgres sequence used to generate new + IDs. + """ + + def __init__( + self, + db_conn, + db: Database, + instance_name: str, + table: str, + instance_column: str, + id_column: str, + sequence_name: str, + ): + self._db = db + self._instance_name = instance_name + self._sequence_name = sequence_name + + # We lock as some functions may be called from DB threads. + self._lock = threading.Lock() + + self._current_positions = self._load_current_ids( + db_conn, table, instance_column, id_column + ) + + # Set of local IDs that we're still processing. The current position + # should be less than the minimum of this set (if not empty). + self._unfinished_ids = set() # type: Set[int] + + def _load_current_ids( + self, db_conn, table: str, instance_column: str, id_column: str + ) -> Dict[str, int]: + sql = """ + SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + GROUP BY %(instance)s + """ % { + "instance": instance_column, + "id": id_column, + "table": table, + } + + cur = db_conn.cursor() + cur.execute(sql) + + # `cur` is an iterable over returned rows, which are 2-tuples. + current_positions = dict(cur) + + cur.close() + + return current_positions + + def _load_next_id_txn(self, txn): + txn.execute("SELECT nextval(?)", (self._sequence_name,)) + (next_id,) = txn.fetchone() + return next_id + + async def get_next(self): + """ + Usage: + with await stream_id_gen.get_next() as stream_id: + # ... persist event ... + """ + next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) + + # Assert the fetched ID is actually greater than what we currently + # believe the ID to be. If not, then the sequence and table have got + # out of sync somehow. + assert self.get_current_token() < next_id + + with self._lock: + self._unfinished_ids.add(next_id) + + @contextlib.contextmanager + def manager(): + try: + yield next_id + finally: + self._mark_id_as_finished(next_id) + + return manager() + + def get_next_txn(self, txn: LoggingTransaction): + """ + Usage: + + stream_id = stream_id_gen.get_next(txn) + # ... persist event ... + """ + + next_id = self._load_next_id_txn(txn) + + with self._lock: + self._unfinished_ids.add(next_id) + + txn.call_after(self._mark_id_as_finished, next_id) + txn.call_on_exception(self._mark_id_as_finished, next_id) + + return next_id + + def _mark_id_as_finished(self, next_id: int): + """The ID has finished being processed so we should advance the + current poistion if possible. + """ + + with self._lock: + self._unfinished_ids.discard(next_id) + + # Figure out if its safe to advance the position by checking there + # aren't any lower allocated IDs that are yet to finish. + if all(c > next_id for c in self._unfinished_ids): + curr = self._current_positions.get(self._instance_name, 0) + self._current_positions[self._instance_name] = max(curr, next_id) + + def get_current_token(self, instance_name: str = None) -> int: + """Gets the current position of a named writer (defaults to current + instance). + + Returns 0 if we don't have a position for the named writer (likely due + to it being a new writer). + """ + + if instance_name is None: + instance_name = self._instance_name + + with self._lock: + return self._current_positions.get(instance_name, 0) + + def get_positions(self) -> Dict[str, int]: + """Get a copy of the current positon map. + """ + + with self._lock: + return dict(self._current_positions) + + def advance(self, instance_name: str, new_id: int): + """Advance the postion of the named writer to the given ID, if greater + than existing entry. + """ + + with self._lock: + self._current_positions[instance_name] = max( + new_id, self._current_positions.get(instance_name, 0) + ) -- cgit 1.5.1 From d7983b63a6746d92225295f1e9d521f847cf8ba7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 May 2020 13:51:08 +0100 Subject: Support any process writing to cache invalidation stream. (#7436) --- changelog.d/7436.misc | 1 + docs/tcp_replication.md | 4 - scripts/synapse_port_db | 4 +- synapse/replication/slave/storage/_base.py | 50 +++---------- synapse/replication/slave/storage/account_data.py | 6 +- synapse/replication/slave/storage/deviceinbox.py | 6 +- synapse/replication/slave/storage/devices.py | 6 +- synapse/replication/slave/storage/events.py | 6 +- synapse/replication/slave/storage/groups.py | 6 +- synapse/replication/slave/storage/presence.py | 6 +- synapse/replication/slave/storage/push_rule.py | 6 +- synapse/replication/slave/storage/pushers.py | 6 +- synapse/replication/slave/storage/receipts.py | 6 +- synapse/replication/slave/storage/room.py | 4 +- synapse/replication/tcp/client.py | 6 +- synapse/replication/tcp/commands.py | 33 -------- synapse/replication/tcp/handler.py | 42 ++--------- synapse/replication/tcp/resource.py | 22 +++++- synapse/replication/tcp/streams/_base.py | 87 +++++++++++++++------- synapse/replication/tcp/streams/events.py | 4 +- synapse/replication/tcp/streams/federation.py | 12 ++- synapse/storage/_base.py | 3 + synapse/storage/data_stores/main/__init__.py | 15 +++- synapse/storage/data_stores/main/cache.py | 84 +++++++++++---------- .../schema/delta/58/05cache_instance.sql.postgres | 30 ++++++++ synapse/storage/prepare_database.py | 2 + 26 files changed, 226 insertions(+), 231 deletions(-) create mode 100644 changelog.d/7436.misc create mode 100644 synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres (limited to 'synapse/storage') diff --git a/changelog.d/7436.misc b/changelog.d/7436.misc new file mode 100644 index 0000000000..f7c4514950 --- /dev/null +++ b/changelog.d/7436.misc @@ -0,0 +1 @@ +Support any process writing to cache invalidation stream. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ab2fffbfe4..db318baa9d 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -219,10 +219,6 @@ Asks the server for the current position of all streams. Inform the server a pusher should be removed -#### INVALIDATE_CACHE (C) - - Inform the server a cache should be invalidated - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e8b698f3ff..acd9ac4b75 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -122,7 +122,7 @@ APPEND_ONLY_TABLES = [ "presence_stream", "push_rules_stream", "ex_outlier_stream", - "cache_invalidation_stream", + "cache_invalidation_stream_by_instance", "public_room_list_stream", "state_group_edges", "stream_ordering_to_exterm", @@ -188,7 +188,7 @@ class MockHomeserver: self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - self.version_string = "Synapse/"+get_version_string(synapse) + self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): return self.clock diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 5d7c8871a4..2904bd0235 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,14 +18,10 @@ from typing import Optional import six -from synapse.storage.data_stores.main.cache import ( - CURRENT_STATE_CACHE_NAME, - CacheInvalidationWorkerStore, -) +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine - -from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) @@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id" - ) # type: Optional[SlavedIdTracker] + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name=hs.get_instance_name(), + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None self.hs = hs - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "caches": - if self._cache_id_gen: - self._cache_id_gen.advance(token) - for row in rows: - if row.cache_func == CURRENT_STATE_CACHE_NAME: - if row.keys is None: - raise Exception( - "Can't send an 'invalidate all' for current state cache" - ) - - room_id = row.keys[0] - members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) - else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - txn.call_after(cache_func.invalidate, keys) - txn.call_after(self._send_invalidation_poke, cache_func, keys) - - def _send_invalidation_poke(self, cache_func, keys): - self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 65e54b1c71..2a4f5c7cfd 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) for row in rows: @@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedAccountDataStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index c923751e50..6e7fd259d4 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) for row in rows: @@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - return super(SlavedDeviceInboxStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 58fb0eaae3..9d8067342f 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) self._invalidate_caches_for_devices(token, rows) @@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedDeviceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_caches_for_devices(self, token, rows): for row in rows: diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 15011259df..b313720a4b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,7 +93,7 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) for row in rows: @@ -111,9 +111,7 @@ class SlavedEventStore( row.relates_to, backfilled=True, ) - return super(SlavedEventStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _process_event_stream_row(self, token, row): data = row.data diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 01bcf0e882..1851e7d525 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedGroupServerStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index fae3125072..bd79ba99be 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) - return super(SlavedPresenceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 6138796da4..5d5816d7eb 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedPushRuleStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 67be337945..cb78b49acb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) - return super(SlavedPusherStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 993432edcb..be716cc558 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "receipts": self._receipts_id_gen.advance(token) for row in rows: @@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - return super(SlavedReceiptsStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 10dda8708f..8873bf37e5 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows(stream_name, token, rows) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3bbf3c3569..20cb8a654f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -100,10 +100,10 @@ class ReplicationDataHandler: token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, instance_name, token, rows) - async def on_position(self, stream_name: str, token: int): - self.store.process_replication_rows(stream_name, token, []) + async def on_position(self, stream_name: str, instance_name: str, token: int): + self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f58e384d17..c04f622816 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -341,37 +341,6 @@ class RemovePusherCommand(Command): return " ".join((self.app_id, self.push_key, self.user_id)) -class InvalidateCacheCommand(Command): - """Sent by the client to invalidate an upstream cache. - - THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE - NOT DISASTROUS IF WE DROP ON THE FLOOR. - - Mainly used to invalidate destination retry timing caches. - - Format:: - - INVALIDATE_CACHE - - Where is a json list. - """ - - NAME = "INVALIDATE_CACHE" - - def __init__(self, cache_func, keys): - self.cache_func = cache_func - self.keys = keys - - @classmethod - def from_line(cls, line): - cache_func, keys_json = line.split(" ", 1) - - return cls(cache_func, json.loads(keys_json)) - - def to_line(self): - return " ".join((self.cache_func, _json_encoder.encode(self.keys))) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client. @@ -439,7 +408,6 @@ _COMMANDS = ( UserSyncCommand, FederationAckCommand, RemovePusherCommand, - InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, @@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = ( ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, - InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b14a3d9fca..7c5d6c76e7 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,18 +15,7 @@ # limitations under the License. import logging -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, -) +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar from prometheus_client import Counter @@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, - InvalidateCacheCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -171,7 +159,7 @@ class ReplicationCommandHandler: return for stream_name, stream in self._streams.items(): - current_token = stream.current_token() + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand(stream_name, self._instance_name, current_token) ) @@ -210,18 +198,6 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE( - self, conn: AbstractConnection, cmd: InvalidateCacheCommand - ): - invalidate_cache_counter.inc() - - if self._is_master: - # We invalidate the cache locally, but then also stream that to other - # workers. - await self._store.invalidate_cache_and_stream( - cmd.cache_func, tuple(cmd.keys) - ) - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() @@ -295,7 +271,7 @@ class ReplicationCommandHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) + logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) @@ -326,7 +302,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. - current_token = stream.current_token() + current_token = stream.current_token(cmd.instance_name) # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -363,7 +339,9 @@ class ReplicationCommandHandler: logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(stream_name, cmd.token) + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) @@ -491,12 +469,6 @@ class ReplicationCommandHandler: cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) - def send_invalidate_cache(self, cache_func: Callable, keys: tuple): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - def send_user_ip( self, user_id: str, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b690abedad..002171ce7c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol -from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.replication.tcp.streams import ( + STREAMS_MAP, + CachesStream, + FederationStream, + Stream, +) from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -71,11 +76,16 @@ class ReplicationStreamer(object): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() self._replication_torture_level = hs.config.replication_torture_level # Work out list of streams that this instance is the source of. self.streams = [] # type: List[Stream] + + # All workers can write to the cache invalidation stream. + self.streams.append(CachesStream(hs)) + if hs.config.worker_app is None: for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: @@ -83,6 +93,10 @@ class ReplicationStreamer(object): # has been disabled on the master. continue + if stream == CachesStream: + # We've already added it above. + continue + self.streams.append(stream(hs)) self.streams_by_name = {stream.NAME: stream for stream in self.streams} @@ -145,7 +159,9 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.current_token(): + if stream.last_token == stream.current_token( + self._instance_name + ): continue if self._replication_torture_level: @@ -157,7 +173,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.current_token(), + stream.current_token(self._instance_name), ) try: updates, current_token, limited = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 084604e8b0..b48a6a3e91 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,20 +95,25 @@ class Stream(object): def __init__( self, local_instance_name: str, - current_token_function: Callable[[], Token], + current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - current_token_function and update_function are callbacks which should be - implemented by subclasses. + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. - current_token_function is called to get the current token of the underlying - stream. It is only meaningful on the process that is the source of the - replication stream (ie, usually the master). + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). - update_function is called to get updates for this stream between a pair of - stream tokens. See the UpdateFunction type definition for more info. + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. Args: local_instance_name: The instance name of the current process @@ -120,13 +125,13 @@ class Stream(object): self.update_function = update_function # The token from which we last asked for updates - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or @@ -138,7 +143,7 @@ class Stream(object): position in stream, and `limited` is whether there are more updates to fetch. """ - current_token = self.current_token() + current_token = self.current_token(self.local_instance_name) updates, current_token, limited = await self.get_updates_since( self.local_instance_name, self.last_token, current_token ) @@ -170,6 +175,16 @@ class Stream(object): return updates, upto_token, limited +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + def db_query_to_update_function( query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: @@ -235,7 +250,7 @@ class BackfillStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_backfill_token, + current_token_without_instance(store.get_current_backfill_token), db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -271,7 +286,9 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), store.get_current_presence_token, update_function + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, ) @@ -296,7 +313,9 @@ class TypingStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), typing_handler.get_current_token, update_function + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, ) @@ -319,7 +338,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_receipt_stream_id, + current_token_without_instance(store.get_max_receipt_stream_id), db_query_to_update_function(store.get_all_updated_receipts), ) @@ -339,7 +358,7 @@ class PushRulesStream(Stream): hs.get_instance_name(), self._current_token, self._update_function ) - def _current_token(self) -> int: + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token @@ -373,7 +392,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - store.get_pushers_stream_token, + current_token_without_instance(store.get_pushers_stream_token), db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -402,12 +421,26 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, - db_query_to_update_function(store.get_all_updated_caches), + self.store.get_cache_stream_token, + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited class PublicRoomsStream(Stream): @@ -431,7 +464,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_public_room_stream_id, + current_token_without_instance(store.get_current_public_room_stream_id), db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -452,7 +485,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -470,7 +503,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_to_device_stream_token, + current_token_without_instance(store.get_to_device_stream_token), db_query_to_update_function(store.get_all_new_device_messages), ) @@ -490,7 +523,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_account_data_stream_id, + current_token_without_instance(store.get_max_account_data_stream_id), db_query_to_update_function(store.get_all_updated_tags), ) @@ -510,7 +543,7 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_max_account_data_stream_id, + current_token_without_instance(self.store.get_max_account_data_stream_id), db_query_to_update_function(self._update_function), ) @@ -541,7 +574,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_group_stream_token, + current_token_without_instance(store.get_group_stream_token), db_query_to_update_function(store.get_all_groups_changes), ) @@ -559,7 +592,7 @@ class UserSignatureStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function( store.get_all_user_signature_changes_for_remotes ), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 890e75d827..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -20,7 +20,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -119,7 +119,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store.get_current_events_token, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index b0505b8a2c..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,11 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, make_http_update_function +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, +) class FederationStream(Stream): @@ -41,7 +45,9 @@ class FederationStream(Stream): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = federation_sender.get_current_token + current_token = current_token_without_instance( + federation_sender.get_current_token + ) update_function = federation_sender.get_replication_rows elif hs.should_send_federation(): @@ -58,7 +64,7 @@ class FederationStream(Stream): super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - def _stub_current_token(): + def _stub_current_token(instance_name: str) -> int: # dummy current-token method for use on workers return 0 diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 13de5f1f62..59073c0a42 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta): self.db = database self.rand = random.SystemRandom() + def process_replication_rows(self, stream_name, instance_name, token, rows): + pass + def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index ceba10882c..cd2a1f0461 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, IdGenerator, + MultiWriterIdGenerator, StreamIdGenerator, ) from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore -from .cache import CacheInvalidationStore +from .cache import CacheInvalidationWorkerStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -112,8 +113,8 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, - CacheInvalidationStore, UIAuthStore, + CacheInvalidationWorkerStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs @@ -170,8 +171,14 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = StreamIdGenerator( - db_conn, "cache_invalidation_stream", "stream_id" + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name="master", + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", ) else: self._cache_id_gen = None diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 4dc5da3fe8..342a87a46b 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,11 +16,10 @@ import itertools import logging -from typing import Any, Iterable, Optional, Tuple - -from twisted.internet import defer +from typing import Any, Iterable, Optional from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def get_all_updated_caches(self, last_id, current_id, limit): + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self._instance_name = hs.get_instance_name() + + async def get_all_updated_caches( + self, instance_name: str, last_id: int, current_id: int, limit: int + ): + """Fetches cache invalidation rows between the two given IDs written + by the given instance. Returns at most `limit` rows. + """ + if last_id == current_id: - return defer.succeed([]) + return [] def get_all_updated_caches_txn(txn): # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) + sql = """ + SELECT stream_id, cache_func, keys, invalidation_ts + FROM cache_invalidation_stream_by_instance + WHERE stream_id > ? AND instance_name = ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, instance_name, limit)) return txn.fetchall() - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_updated_caches", get_all_updated_caches_txn ) + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "caches": + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) -class CacheInvalidationStore(CacheInvalidationWorkerStore): - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. + for row in rows: + if row.cache_func == CURRENT_STATE_CACHE_NAME: + if row.keys is None: + raise Exception( + "Can't send an 'invalidate all' for current state cache" + ) - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - cache_func = getattr(self, cache_name, None) - if not cache_func: - return - - cache_func.invalidate(keys) - await self.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, - cache_func.__name__, - keys, - ) + room_id = row.keys[0] + members_changed = set(row.keys[1:]) + self._invalidate_state_caches(room_id, members_changed) + else: + self._attempt_to_invalidate_cache(row.cache_func, row.keys) + + super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves @@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore): # the transaction. However, we want to only get an ID when we want # to use it, here, so we need to call __enter__ manually, and have # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) + stream_id = self._cache_id_gen.get_next_txn(txn) txn.call_after(self.hs.get_notifier().on_new_replication_data) if keys is not None: @@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore): self.db.simple_insert_txn( txn, - table="cache_invalidation_stream", + table="cache_invalidation_stream_by_instance", values={ "stream_id": stream_id, + "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, "invalidation_ts": self.clock.time_msec(), }, ) - def get_cache_stream_token(self): + def get_cache_stream_token(self, instance_name): if self._cache_id_gen: - return self._cache_id_gen.get_current_token() + return self._cache_id_gen.get_current_token(instance_name) else: return 0 diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres new file mode 100644 index 0000000000..aa46eb0e10 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres @@ -0,0 +1,30 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We keep the old table here to enable us to roll back. It doesn't matter +-- that we have dropped all the data here. +TRUNCATE cache_invalidation_stream; + +CREATE TABLE cache_invalidation_stream_by_instance ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + cache_func TEXT NOT NULL, + keys TEXT[], + invalidation_ts BIGINT +); + +CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id); + +CREATE SEQUENCE cache_invalidation_stream_seq; diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 1712932f31..640f242584 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -29,6 +29,8 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. +# XXX: If you're about to bump this to 59 (or higher) please create an update +# that drops the unused `cache_invalidation_stream` table, as per #7436! SCHEMA_VERSION = 58 dir_path = os.path.abspath(os.path.dirname(__file__)) -- cgit 1.5.1 From a4a5ec4096f8de938f4a6e4264aeaaa0e0b26463 Mon Sep 17 00:00:00 2001 From: Manuel Stahl <37705355+awesome-manuel@users.noreply.github.com> Date: Thu, 7 May 2020 21:33:07 +0200 Subject: Add room details admin endpoint (#7317) --- changelog.d/7317.feature | 1 + docs/admin_api/rooms.md | 54 ++++++++++++++++++++++++++++++++ synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 26 ++++++++++++++- synapse/storage/data_stores/main/room.py | 31 ++++++++++++++++++ tests/rest/admin/test_room.py | 41 ++++++++++++++++++++++++ tests/storage/test_room.py | 11 +++++++ 7 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7317.feature (limited to 'synapse/storage') diff --git a/changelog.d/7317.feature b/changelog.d/7317.feature new file mode 100644 index 0000000000..23c063f280 --- /dev/null +++ b/changelog.d/7317.feature @@ -0,0 +1 @@ +Add room details admin endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 26fe8b8679..624e7745ba 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -264,3 +264,57 @@ Response: Once the `next_token` parameter is no longer present, we know we've reached the end of the list. + +# DRAFT: Room Details API + +The Room Details admin API allows server admins to get all details of a room. + +This API is still a draft and details might change! + +The following fields are possible in the JSON response body: + +* `room_id` - The ID of the room. +* `name` - The name of the room. +* `canonical_alias` - The canonical (main) alias address of the room. +* `joined_members` - How many users are currently in the room. +* `joined_local_members` - How many local users are currently in the room. +* `version` - The version of the room as a string. +* `creator` - The `user_id` of the room creator. +* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active. +* `federatable` - Whether users on other servers can join this room. +* `public` - Whether the room is visible in room directory. +* `join_rules` - The type of rules used for users wishing to join this room. One of: ["public", "knock", "invite", "private"]. +* `guest_access` - Whether guests can join the room. One of: ["can_join", "forbidden"]. +* `history_visibility` - Who can see the room history. One of: ["invited", "joined", "shared", "world_readable"]. +* `state_events` - Total number of state_events of a room. Complexity of the room. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms/ + +{} +``` + +Response: + +``` +{ + "room_id": "!mscvqgqpHYjBGDxNym:matrix.org", + "name": "Music Theory", + "canonical_alias": "#musictheory:matrix.org", + "joined_members": 127 + "joined_local_members": 2, + "version": "1", + "creator": "@foo:matrix.org", + "encryption": null, + "federatable": true, + "public": true, + "join_rules": "invite", + "guest_access": null, + "history_visibility": "shared", + "state_events": 93534 +} +``` diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index ed70d448a1..6b85148a32 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -32,6 +32,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( JoinRoomAliasServlet, ListRoomRestServlet, + RoomRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -193,6 +194,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index d1bdb64111..7d40001988 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -26,6 +26,7 @@ from synapse.http.servlet import ( ) from synapse.rest.admin._base import ( admin_patterns, + assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, ) @@ -169,7 +170,7 @@ class ListRoomRestServlet(RestServlet): in a dictionary containing room information. Supports pagination. """ - PATTERNS = admin_patterns("/rooms") + PATTERNS = admin_patterns("/rooms$") def __init__(self, hs): self.store = hs.get_datastore() @@ -253,6 +254,29 @@ class ListRoomRestServlet(RestServlet): return 200, response +class RoomRestServlet(RestServlet): + """Get room details. + + TODO: Add on_POST to allow room creation without joining the room + """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + await assert_requester_is_admin(self.auth, request) + + ret = await self.store.get_room_with_stats(room_id) + if not ret: + raise NotFoundError("Room not found") + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 147eba1df7..cafa664c16 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -98,6 +98,37 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) + def get_room_with_stats(self, room_id: str): + """Retrieve room with statistics. + + Args: + room_id: The ID of the room to retrieve. + Returns: + A dict containing the room information, or None if the room is unknown. + """ + + def get_room_with_stats_txn(txn, room_id): + sql = """ + SELECT room_id, state.name, state.canonical_alias, curr.joined_members, + curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, + rooms.creator, state.encryption, state.is_federatable AS federatable, + rooms.is_public AS public, state.join_rules, state.guest_access, + state.history_visibility, curr.current_state_events AS state_events + FROM rooms + LEFT JOIN room_stats_state state USING (room_id) + LEFT JOIN room_stats_current curr USING (room_id) + WHERE room_id = ? + """ + txn.execute(sql, [room_id]) + res = self.db.cursor_to_dict(txn)[0] + res["federatable"] = bool(res["federatable"]) + res["public"] = bool(res["public"]) + return res + + return self.db.runInteraction( + "get_room_with_stats", get_room_with_stats_txn, room_id + ) + def get_public_room_ids(self): return self.db.simple_select_onecol( table="rooms", diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 249c93722f..54cd24bf64 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -701,6 +701,47 @@ class RoomTestCase(unittest.HomeserverTestCase): _search_test(None, "bar") _search_test(None, "", expected_http_code=400) + def test_single_room(self): + """Test that a single room can be requested correctly""" + # Create two test rooms + room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + room_name_1 = "something" + room_name_2 = "else" + + # Set the name for each room + self.helper.send_state( + room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok, + ) + self.helper.send_state( + room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok, + ) + + url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,) + request, channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + self.assertIn("room_id", channel.json_body) + self.assertIn("name", channel.json_body) + self.assertIn("canonical_alias", channel.json_body) + self.assertIn("joined_members", channel.json_body) + self.assertIn("joined_local_members", channel.json_body) + self.assertIn("version", channel.json_body) + self.assertIn("creator", channel.json_body) + self.assertIn("encryption", channel.json_body) + self.assertIn("federatable", channel.json_body) + self.assertIn("public", channel.json_body) + self.assertIn("join_rules", channel.json_body) + self.assertIn("guest_access", channel.json_body) + self.assertIn("history_visibility", channel.json_body) + self.assertIn("state_events", channel.json_body) + + self.assertEqual(room_id_1, channel.json_body["room_id"]) + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 086adeb8fd..3b78d48896 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -55,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase): (yield self.store.get_room(self.room.to_string())), ) + @defer.inlineCallbacks + def test_get_room_with_stats(self): + self.assertDictContainsSubset( + { + "room_id": self.room.to_string(), + "creator": self.u_creator.to_string(), + "public": True, + }, + (yield self.store.get_room_with_stats(self.room.to_string())), + ) + class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks -- cgit 1.5.1 From a8580c5f1986e8d2f4bf9cf64190ac5e3d84aaf6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 11 May 2020 16:55:57 +0100 Subject: Remove unused store method get_hosts_in_room (#7448) --- changelog.d/7448.misc | 1 + synapse/storage/data_stores/main/roommember.py | 10 ---------- 2 files changed, 1 insertion(+), 10 deletions(-) create mode 100644 changelog.d/7448.misc (limited to 'synapse/storage') diff --git a/changelog.d/7448.misc b/changelog.d/7448.misc new file mode 100644 index 0000000000..2163cc9b97 --- /dev/null +++ b/changelog.d/7448.misc @@ -0,0 +1 @@ +Remove storage method `get_hosts_in_room` that is no longer called anywhere. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index d5bd0cb5cf..ae9a76713c 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -153,16 +153,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._check_safe_current_state_events_membership_updated_txn, ) - @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) - def get_hosts_in_room(self, room_id, cache_context): - """Returns the set of all hosts currently in the room - """ - user_ids = yield self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) - hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) - return hosts - @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): return self.db.runInteraction( -- cgit 1.5.1 From 7cb8b4bc67042a39bd1b0e05df46089a2fce1955 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 12 May 2020 03:45:23 +1000 Subject: Allow configuration of Synapse's cache without using synctl or environment variables (#6391) --- changelog.d/6391.feature | 1 + docs/sample_config.yaml | 43 +++++- synapse/api/auth.py | 4 +- synapse/app/homeserver.py | 5 +- synapse/config/cache.py | 164 ++++++++++++++++++++++ synapse/config/database.py | 6 - synapse/config/homeserver.py | 2 + synapse/http/client.py | 6 +- synapse/metrics/_exposition.py | 12 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_rule_evaluator.py | 4 +- synapse/replication/slave/storage/client_ips.py | 3 +- synapse/state/__init__.py | 4 +- synapse/storage/data_stores/main/client_ips.py | 3 +- synapse/storage/data_stores/main/events_worker.py | 5 +- synapse/storage/data_stores/state/store.py | 6 +- synapse/util/caches/__init__.py | 144 ++++++++++--------- synapse/util/caches/descriptors.py | 36 ++++- synapse/util/caches/expiringcache.py | 29 +++- synapse/util/caches/lrucache.py | 52 +++++-- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/stream_change_cache.py | 33 ++++- synapse/util/caches/ttlcache.py | 2 +- tests/config/test_cache.py | 127 +++++++++++++++++ tests/storage/test__base.py | 8 +- tests/storage/test_appservice.py | 10 +- tests/storage/test_base.py | 3 +- tests/test_metrics.py | 34 +++++ tests/util/test_expiring_cache.py | 2 +- tests/util/test_lrucache.py | 6 +- tests/util/test_stream_change_cache.py | 5 +- tests/utils.py | 1 + 32 files changed, 620 insertions(+), 146 deletions(-) create mode 100644 changelog.d/6391.feature create mode 100644 synapse/config/cache.py create mode 100644 tests/config/test_cache.py (limited to 'synapse/storage') diff --git a/changelog.d/6391.feature b/changelog.d/6391.feature new file mode 100644 index 0000000000..f123426e23 --- /dev/null +++ b/changelog.d/6391.feature @@ -0,0 +1 @@ +Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 5abeaf519b..8a8415b9a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -603,6 +603,45 @@ acme: +## Caching ## + +# Caching can be configured through the following options. +# +# A cache 'factor' is a multiplier that can be applied to each of +# Synapse's caches in order to increase or decrease the maximum +# number of entries that can be stored. + +# The number of events to cache in memory. Not affected by +# caches.global_factor. +# +#event_cache_size: 10K + +caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + ## Database ## # The 'database' setting defines the database that synapse uses to store all of @@ -646,10 +685,6 @@ database: args: database: DATADIR/homeserver.db -# Number of events to cache in memory. -# -#event_cache_size: 10K - ## Logging ## diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1ad5ff9410..e009b1a760 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap, UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -73,7 +73,7 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + self.token_cache = LruCache(10000) register_cache("cache", "token_cache", self.token_cache) self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bc8695d8dd..d7f337e586 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module @@ -516,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size # # Performance statistics diff --git a/synapse/config/cache.py b/synapse/config/cache.py new file mode 100644 index 0000000000..91036a012e --- /dev/null +++ b/synapse/config/cache.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Dict + +from ._base import Config, ConfigError + +# The prefix for all cache factor-related environment variables +_CACHES = {} +_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" +_DEFAULT_FACTOR_SIZE = 0.5 +_DEFAULT_EVENT_CACHE_SIZE = "10K" + + +class CacheProperties(object): + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None + + +properties = CacheProperties() + + +def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): + """Register a cache that's size can dynamically change + + Args: + cache_name: A reference to the cache + cache_resize_callback: A callback function that will be ran whenever + the cache needs to be resized + """ + _CACHES[cache_name.lower()] = cache_resize_callback + + # Ensure all loaded caches are sized appropriately + # + # This method should only run once the config has been read, + # as it uses values read from it + if properties.resize_all_caches_func: + properties.resize_all_caches_func() + + +class CacheConfig(Config): + section = "caches" + _environ = os.environ + + @staticmethod + def reset(): + """Resets the caches to their defaults. Used for tests.""" + properties.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + properties.resize_all_caches_func = None + _CACHES.clear() + + def generate_config_section(self, **kwargs): + return """\ + ## Caching ## + + # Caching can be configured through the following options. + # + # A cache 'factor' is a multiplier that can be applied to each of + # Synapse's caches in order to increase or decrease the maximum + # number of entries that can be stored. + + # The number of events to cache in memory. Not affected by + # caches.global_factor. + # + #event_cache_size: 10K + + caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + """ + + def read_config(self, config, **kwargs): + self.event_cache_size = self.parse_size( + config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) + ) + self.cache_factors = {} # type: Dict[str, float] + + cache_config = config.get("caches") or {} + self.global_factor = cache_config.get( + "global_factor", properties.default_factor_size + ) + if not isinstance(self.global_factor, (int, float)): + raise ConfigError("caches.global_factor must be a number.") + + # Set the global one so that it's reflected in new caches + properties.default_factor_size = self.global_factor + + # Load cache factors from the config + individual_factors = cache_config.get("per_cache_factors") or {} + if not isinstance(individual_factors, dict): + raise ConfigError("caches.per_cache_factors must be a dictionary") + + # Override factors from environment if necessary + individual_factors.update( + { + key[len(_CACHE_PREFIX) + 1 :].lower(): float(val) + for key, val in self._environ.items() + if key.startswith(_CACHE_PREFIX + "_") + } + ) + + for cache, factor in individual_factors.items(): + if not isinstance(factor, (int, float)): + raise ConfigError( + "caches.per_cache_factors.%s must be a number" % (cache.lower(),) + ) + self.cache_factors[cache.lower()] = factor + + # Resize all caches (if necessary) with the new factors we've loaded + self.resize_all_caches() + + # Store this function so that it can be called from other classes without + # needing an instance of Config + properties.resize_all_caches_func = self.resize_all_caches + + def resize_all_caches(self): + """Ensure all cache sizes are up to date + + For each cache, run the mapped callback function with either + a specific cache factor or the default, global one. + """ + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/database.py b/synapse/config/database.py index 5b662d1b01..1064c2697b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -68,10 +68,6 @@ database: name: sqlite3 args: database: %(database_path)s - -# Number of events to cache in memory. -# -#event_cache_size: 10K """ @@ -116,8 +112,6 @@ class DatabaseConfig(Config): self.databases = [] def read_config(self, config, **kwargs): - self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - # We *experimentally* support specifying multiple databases via the # `databases` key. This is a map from a label to database config in the # same format as the `database` config option, plus an extra diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 996d3e6bf7..2c7b3a699f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig from .consent_config import ConsentConfig @@ -55,6 +56,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + CacheConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, diff --git a/synapse/http/client.py b/synapse/http/client.py index 58eb47c69c..3cef747a4d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred -from synapse.util.caches import CACHE_SIZE_FACTOR logger = logging.getLogger(__name__) @@ -241,7 +240,10 @@ class SimpleHttpClient(object): # tends to do so in batches, so we need to allow the pool to keep # lots of idle connections around. pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + # XXX: The justification for using the cache factor here is that larger instances + # will need both more cache and more connections. + # Still, this should probably be a separate dial + pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 self.agent = ProxyAgent( diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index a248103191..ab7f948ed4 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -33,6 +33,8 @@ from prometheus_client import REGISTRY from twisted.web.resource import Resource +from synapse.util import caches + try: from prometheus_client.samples import Sample except ImportError: @@ -103,13 +105,15 @@ def nameify_sample(sample): def generate_latest(registry, emit_help=False): - output = [] - for metric in registry.collect(): + # Trigger the cache metrics to be rescraped, which updates the common + # metrics but do not produce metrics themselves + for collector in caches.collectors_by_name.values(): + collector.collect() - if metric.name.startswith("__unused"): - continue + output = [] + for metric in registry.collect(): if not metric.samples: # No samples, don't bother. continue diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 433ca2f416..e75d964ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache( "cache", "push_rules_delta_state_cache_metric", cache=[], # Meaningless size, as this isn't a cache that stores values + resizable=False, ) @@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", - cache=[], # Meaningless size, as this isn't a cache that stores values + cache=[], # Meaningless size, as this isn't a cache that stores values, + resizable=False, ) @defer.inlineCallbacks diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4cd702b5fa..11032491af 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -22,7 +22,7 @@ from six import string_types from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object): # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +regex_cache = LruCache(50000) register_cache("cache", "regex_push_cache", regex_cache) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index fbf996e33a..1a38f53dfb 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,6 @@ from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.database import Database -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore @@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4afefc6b1d..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -35,7 +35,6 @@ from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -53,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -447,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 92bc06919b..71f8d43a76 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database, make_tuple_comparison_clause -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) super(ClientIpStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 73df6b33ba..b8c1bbdf99 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -75,7 +75,10 @@ class EventsWorkerStore(SQLBaseStore): super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 57a5267663..f3ad1e4369 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.types import StateMap -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), + 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), + "*stateGroupMembersCache*", 500000, ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index da5077b471..4b8a0c7a8f 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,27 +15,17 @@ # limitations under the License. import logging -import os -from typing import Dict +from typing import Callable, Dict, Optional import six from six.moves import intern -from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily - -logger = logging.getLogger(__name__) - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) +import attr +from prometheus_client.core import Gauge +from synapse.config.cache import add_resizable_cache -def get_cache_factor_for(cache_name): - env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() - factor = os.environ.get(env_var) - if factor: - return float(factor) - - return CACHE_SIZE_FACTOR - +logger = logging.getLogger(__name__) caches_by_name = {} collectors_by_name = {} # type: Dict @@ -44,6 +34,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -53,67 +44,82 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache, collect_callback=None): - """Register a cache object for metric collection. +@attr.s +class CacheMetric(object): + + _cache = attr.ib() + _cache_type = attr.ib(type=str) + _cache_name = attr.ib(type=str) + _collect_callback = attr.ib(type=Optional[Callable]) + + hits = attr.ib(default=0) + misses = attr.ib(default=0) + evicted_size = attr.ib(default=0) + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + try: + if self._cache_type == "response_cache": + response_cache_size.labels(self._cache_name).set(len(self._cache)) + response_cache_hits.labels(self._cache_name).set(self.hits) + response_cache_evicted.labels(self._cache_name).set(self.evicted_size) + response_cache_total.labels(self._cache_name).set( + self.hits + self.misses + ) + else: + cache_size.labels(self._cache_name).set(len(self._cache)) + cache_hits.labels(self._cache_name).set(self.hits) + cache_evicted.labels(self._cache_name).set(self.evicted_size) + cache_total.labels(self._cache_name).set(self.hits + self.misses) + if getattr(self._cache, "max_size", None): + cache_max_size.labels(self._cache_name).set(self._cache.max_size) + if self._collect_callback: + self._collect_callback() + except Exception as e: + logger.warning("Error calculating metrics for %s: %s", self._cache_name, e) + raise + + +def register_cache( + cache_type: str, + cache_name: str, + cache, + collect_callback: Optional[Callable] = None, + resizable: bool = True, + resize_callback: Optional[Callable] = None, +) -> CacheMetric: + """Register a cache object for metric collection and resizing. Args: - cache_type (str): - cache_name (str): name of the cache - cache (object): cache itself - collect_callback (callable|None): if not None, a function which is called during - metric collection to update additional metrics. + cache_type + cache_name: name of the cache + cache: cache itself + collect_callback: If given, a function which is called during metric + collection to update additional metrics. + resizable: Whether this cache supports being resized. + resize_callback: A function which can be called to resize the cache. Returns: CacheMetric: an object which provides inc_{hits,misses,evictions} methods """ + if resizable: + if not resize_callback: + resize_callback = getattr(cache, "set_cache_factor") + add_resizable_cache(cache_name, resize_callback) - # Check if the metric is already registered. Unregister it, if so. - # This usually happens during tests, as at runtime these caches are - # effectively singletons. + metric = CacheMetric(cache, cache_type, cache_name, collect_callback) metric_name = "cache_%s_%s" % (cache_type, cache_name) - if metric_name in collectors_by_name.keys(): - REGISTRY.unregister(collectors_by_name[metric_name]) - - class CacheMetric(object): - - hits = 0 - misses = 0 - evicted_size = 0 - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def inc_evictions(self, size=1): - self.evicted_size += size - - def describe(self): - return [] - - def collect(self): - try: - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) - if collect_callback: - collect_callback() - except Exception as e: - logger.warning("Error calculating metrics for %s: %s", cache_name, e) - raise - - yield GaugeMetricFamily("__unused", "") - - metric = CacheMetric() - REGISTRY.register(metric) caches_by_name[cache_name] = cache collectors_by_name[metric_name] = metric return metric diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2e8f6543e5..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import inspect import logging @@ -30,7 +31,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -81,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -89,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -99,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -111,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cddf1ed515..2726b67b6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -18,6 +18,7 @@ from collections import OrderedDict from six import iteritems, itervalues +from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -51,15 +52,16 @@ class ExpiringCache(object): an item on access. Defaults to False. iterable (bool): If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. - """ self._cache_name = cache_name + self._original_max_size = max_len + + self._max_size = int(max_len * cache_config.properties.default_factor_size) + self._clock = clock - self._max_len = max_len self._expiry_ms = expiry_ms - self._reset_expiry_on_get = reset_expiry_on_get self._cache = OrderedDict() @@ -82,9 +84,11 @@ class ExpiringCache(object): def __setitem__(self, key, value): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + self.evict() + def evict(self): # Evict if there are now too many items - while self._max_len and len(self) > self._max_len: + while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: self.metrics.inc_evictions(len(value.value)) @@ -170,6 +174,23 @@ class ExpiringCache(object): else: return len(self._cache) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self._max_size: + self._max_size = new_size + self.evict() + return True + return False + class _CacheEntry(object): __slots__ = ["time", "value"] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1536cb64f3..29fabac3cd 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading from functools import wraps +from typing import Callable, Optional, Type, Union +from synapse.config import cache as cache_config from synapse.util.caches.treecache import TreeCache @@ -52,17 +53,18 @@ class LruCache(object): def __init__( self, - max_size, - keylen=1, - cache_type=dict, - size_callback=None, - evicted_callback=None, + max_size: int, + keylen: int = 1, + cache_type: Type[Union[dict, TreeCache]] = dict, + size_callback: Optional[Callable] = None, + evicted_callback: Optional[Callable] = None, + apply_cache_factor_from_config: bool = True, ): """ Args: - max_size (int): + max_size: The maximum amount of entries the cache can hold - keylen (int): + keylen: The length of the tuple used as the cache key cache_type (type): type of underlying cache to be used. Typically one of dict @@ -73,9 +75,23 @@ class LruCache(object): evicted_callback (func(int)|None): if not None, called on eviction with the size of the evicted entry + + apply_cache_factor_from_config (bool): If true, `max_size` will be + multiplied by a cache factor derived from the homeserver config """ cache = cache_type() self.cache = cache # Used for introspection. + + # Save the original max size, and apply the default size factor. + self._original_max_size = max_size + # We previously didn't apply the cache factor here, and as such some caches were + # not affected by the global cache factor. Add an option here to disable applying + # the cache factor when a cache is created + if apply_cache_factor_from_config: + self.max_size = int(max_size * cache_config.properties.default_factor_size) + else: + self.max_size = int(max_size) + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -83,7 +99,7 @@ class LruCache(object): lock = threading.Lock() def evict(): - while cache_len() > max_size: + while cache_len() > self.max_size: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) @@ -236,6 +252,7 @@ class LruCache(object): return key in cache self.sentinel = object() + self._on_resize = evict self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -266,3 +283,20 @@ class LruCache(object): def __contains__(self, key): return self.contains(key) + + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self.max_size: + self.max_size = new_size + self._on_resize() + return True + return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index b68f9fe0d4..a6c60888e5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,7 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache("response_cache", name, self) + self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self): return len(self.pending_result_cache) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e54f80d76e..2a161bf244 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types @@ -46,7 +47,8 @@ class StreamChangeCache: max_size=10000, prefilled_cache: Optional[Mapping[EntityType, int]] = None, ): - self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) + self._original_max_size = max_size + self._max_size = math.floor(max_size) self._entity_to_key = {} # type: Dict[EntityType, int] # map from stream id to the a set of entities which changed at that stream id. @@ -58,12 +60,31 @@ class StreamChangeCache: # self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = caches.register_cache("cache", self.name, self._cache) + self.metrics = caches.register_cache( + "cache", self.name, self._cache, resize_callback=self.set_cache_factor + ) if prefilled_cache: for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = math.floor(self._original_max_size * factor) + if new_size != self._max_size: + self.max_size = new_size + self._evict() + return True + return False + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ @@ -171,6 +192,7 @@ class StreamChangeCache: e1 = self._cache[stream_pos] = set() e1.add(entity) self._entity_to_key[entity] = stream_pos + self._evict() # if the cache is too big, remove entries while len(self._cache) > self._max_size: @@ -179,6 +201,13 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] + def _evict(self): + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + self._entity_to_key.pop(entity, None) + def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 99646c7cf0..6437aa907e 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -38,7 +38,7 @@ class TTLCache(object): self._timer = timer - self._metrics = register_cache("ttl", cache_name, self) + self._metrics = register_cache("ttl", cache_name, self, resizable=False) def set(self, key, value, ttl): """Add/update an entry in the cache diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..2920279125 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. + """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. + """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 6857933540..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. """ @@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and diff --git a/tests/utils.py b/tests/utils.py index f9be62b499..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: -- cgit 1.5.1 From 1a1da60ad2c9172fe487cd38a164b39df60f4cb5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 12 May 2020 11:20:48 +0100 Subject: Fix new flake8 errors (#7470) --- changelog.d/7470.misc | 1 + synapse/app/_base.py | 5 +++-- synapse/config/server.py | 2 +- synapse/notifier.py | 10 ++++++---- synapse/push/mailer.py | 7 +++++-- synapse/storage/database.py | 4 ++-- tests/config/test_load.py | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 changelog.d/7470.misc (limited to 'synapse/storage') diff --git a/changelog.d/7470.misc b/changelog.d/7470.misc new file mode 100644 index 0000000000..45e66ecf48 --- /dev/null +++ b/changelog.d/7470.misc @@ -0,0 +1 @@ +Fix linting errors in new version of Flake8. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 628292b890..dedff81af3 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -22,6 +22,7 @@ import sys import traceback from daemonize import Daemonize +from typing_extensions import NoReturn from twisted.internet import defer, error, reactor from twisted.protocols.tls import TLSMemoryBIOFactory @@ -139,9 +140,9 @@ def start_reactor( run() -def quit_with_error(error_string): +def quit_with_error(error_string: str) -> NoReturn: message_lines = error_string.split("\n") - line_length = max(len(l) for l in message_lines if len(l) < 80) + 2 + line_length = max(len(line) for line in message_lines if len(line) < 80) + 2 sys.stderr.write("*" * line_length + "\n") for line in message_lines: sys.stderr.write(" %s\n" % (line.rstrip(),)) diff --git a/synapse/config/server.py b/synapse/config/server.py index 6d88231843..ed28da3deb 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -522,7 +522,7 @@ class ServerConfig(Config): ) def has_tls_listener(self) -> bool: - return any(l["tls"] for l in self.listeners) + return any(listener["tls"] for listener in self.listeners) def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs diff --git a/synapse/notifier.py b/synapse/notifier.py index 71d9ed62b0..87c120a59c 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import Callable, List +from typing import Callable, Iterable, List, TypeVar from prometheus_client import Counter @@ -42,12 +42,14 @@ users_woken_by_stream_counter = Counter( "synapse_notifier_users_woken_by_stream", "", ["stream"] ) +T = TypeVar("T") + # TODO(paul): Should be shared somewhere -def count(func, l): - """Return the number of items in l for which func returns true.""" +def count(func: Callable[[T], bool], it: Iterable[T]) -> int: + """Return the number of items in it for which func returns true.""" n = 0 - for x in l: + for x in it: if func(x): n += 1 return n diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 73580c1c6c..ab33abbeed 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,6 +19,7 @@ import logging import time from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from typing import Iterable, List, TypeVar from six.moves import urllib @@ -41,6 +42,8 @@ from synapse.visibility import filter_events_for_client logger = logging.getLogger(__name__) +T = TypeVar("T") + MESSAGE_FROM_PERSON_IN_ROOM = ( "You have a message on %(app)s from %(person)s in the %(room)s room..." @@ -638,10 +641,10 @@ def safe_text(raw_text): ) -def deduped_ordered_list(l): +def deduped_ordered_list(it: Iterable[T]) -> List[T]: seen = set() ret = [] - for item in l: + for item in it: if item not in seen: seen.add(item) ret.append(item) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2b635d6ca0..c3d0863429 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -214,9 +214,9 @@ class LoggingTransaction: def executemany(self, sql: str, *args: Any): self._do_execute(self.txn.executemany, sql, *args) - def _make_sql_one_line(self, sql): + def _make_sql_one_line(self, sql: str) -> str: "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + return " ".join(line.strip() for line in sql.splitlines() if line.strip()) def _do_execute(self, func, sql, *args): sql = self._make_sql_one_line(sql) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index b3e557bd6a..734a9983e8 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -122,7 +122,7 @@ class ConfigLoadingTestCase(unittest.TestCase): with open(self.file, "r") as f: contents = f.readlines() - contents = [l for l in contents if needle not in l] + contents = [line for line in contents if needle not in line] with open(self.file, "w") as f: f.write("".join(contents)) -- cgit 1.5.1 From 782e4e64df58dd9e114e7e1e467ae55219336b79 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 13:38:22 +0100 Subject: Shuffle persist event data store functions. (#7440) The aim here is to get to a stage where we have a `PersistEventStore` that holds all the write methods used during event persistence, so that we can take that class out of the `DataStore` mixin and instansiate it separately. This will allow us to instansiate it on processes other than master, while also ensuring it is only available on processes that are configured to write to events stream. This is a bit of an architectural change, where we end up with multiple classes per data store (rather than one per data store we have now). We end up having: 1. Storage classes that provide high level APIs that can talk to multiple data stores. 2. Data store modules that consist of classes that must point at the same database instance. 3. Classes in a data store that can be instantiated on processes depending on config. --- changelog.d/7440.misc | 1 + synapse/storage/data_stores/__init__.py | 9 + synapse/storage/data_stores/main/__init__.py | 8 +- synapse/storage/data_stores/main/censor_events.py | 213 ++++ .../storage/data_stores/main/event_federation.py | 83 -- .../storage/data_stores/main/event_push_actions.py | 74 -- synapse/storage/data_stores/main/events.py | 1216 ++++++++------------ synapse/storage/data_stores/main/events_worker.py | 152 ++- synapse/storage/data_stores/main/metrics.py | 128 +++ synapse/storage/data_stores/main/purge_events.py | 399 +++++++ synapse/storage/data_stores/main/rejections.py | 11 - synapse/storage/data_stores/main/relations.py | 60 +- synapse/storage/data_stores/main/room.py | 49 - synapse/storage/data_stores/main/roommember.py | 90 -- synapse/storage/data_stores/main/search.py | 23 - synapse/storage/data_stores/main/signatures.py | 30 - synapse/storage/data_stores/main/state.py | 35 - synapse/storage/persist_events.py | 27 +- tests/storage/test_event_push_actions.py | 3 +- 19 files changed, 1376 insertions(+), 1235 deletions(-) create mode 100644 changelog.d/7440.misc create mode 100644 synapse/storage/data_stores/main/censor_events.py create mode 100644 synapse/storage/data_stores/main/metrics.py create mode 100644 synapse/storage/data_stores/main/purge_events.py (limited to 'synapse/storage') diff --git a/changelog.d/7440.misc b/changelog.d/7440.misc new file mode 100644 index 0000000000..2505cdc66f --- /dev/null +++ b/changelog.d/7440.misc @@ -0,0 +1 @@ +Refactor event persistence database functions in preparation for allowing them to be run on non-master processes. diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index e1d03429ca..7c14ba91c2 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -15,6 +15,7 @@ import logging +from synapse.storage.data_stores.main.events import PersistEventsStore from synapse.storage.data_stores.state import StateGroupDataStore from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine @@ -39,6 +40,7 @@ class DataStores(object): self.databases = [] self.main = None self.state = None + self.persist_events = None for database_config in hs.config.database.databases: db_name = database_config.name @@ -64,6 +66,13 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) + # If we're on a process that can persist events (currently + # master), also instansiate a `PersistEventsStore` + if hs.config.worker.worker_app is None: + self.persist_events = PersistEventsStore( + hs, database, self.main + ) + if "state" in database_config.data_stores: logger.info("Starting 'state' data store") diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index cd2a1f0461..5df9dce79d 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -34,6 +34,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore from .cache import CacheInvalidationWorkerStore +from .censor_events import CensorEventsStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -42,16 +43,17 @@ from .e2e_room_keys import EndToEndRoomKeyStore from .end_to_end_keys import EndToEndKeyStore from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore -from .events import EventsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore from .media_repository import MediaRepositoryStore +from .metrics import ServerMetricsStore from .monthly_active_users import MonthlyActiveUsersStore from .openid import OpenIdStore from .presence import PresenceStore, UserPresenceState from .profile import ProfileStore +from .purge_events import PurgeEventsStore from .push_rule import PushRuleStore from .pusher import PusherStore from .receipts import ReceiptsStore @@ -88,7 +90,7 @@ class DataStore( StateStore, SignatureStore, ApplicationServiceStore, - EventsStore, + PurgeEventsStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, @@ -113,8 +115,10 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, + CensorEventsStore, UIAuthStore, CacheInvalidationWorkerStore, + ServerMetricsStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py new file mode 100644 index 0000000000..49683b4d5a --- /dev/null +++ b/synapse/storage/data_stores/main/censor_events.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from twisted.internet import defer + +from synapse.events.utils import prune_event_dict +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore +from synapse.storage.data_stores.main.events import encode_json +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class CensorEventsStore(CacheInvalidationWorkerStore, EventsWorkerStore, SQLBaseStore): + def __init__(self, database: Database, db_conn, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + # This should only exist on master for now + assert ( + hs.config.worker.worker_app is None + ), "Can only instantiate CensorEventsStore on master" + + def _censor_redactions(): + return run_as_background_process( + "_censor_redactions", self._censor_redactions + ) + + if self.hs.config.redaction_retention_period is not None: + hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + + async def _censor_redactions(self): + """Censors all redactions older than the configured period that haven't + been censored yet. + + By censor we mean update the event_json table with the redacted event. + """ + + if self.hs.config.redaction_retention_period is None: + return + + if not ( + await self.db.updates.has_completed_background_update( + "redactions_have_censored_ts_idx" + ) + ): + # We don't want to run this until the appropriate index has been + # created. + return + + before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + + # We fetch all redactions that: + # 1. point to an event we have, + # 2. has a received_ts from before the cut off, and + # 3. we haven't yet censored. + # + # This is limited to 100 events to ensure that we don't try and do too + # much at once. We'll get called again so this should eventually catch + # up. + sql = """ + SELECT redactions.event_id, redacts FROM redactions + LEFT JOIN events AS original_event ON ( + redacts = original_event.event_id + ) + WHERE NOT have_censored + AND redactions.received_ts <= ? + ORDER BY redactions.received_ts ASC + LIMIT ? + """ + + rows = await self.db.execute( + "_censor_redactions_fetch", None, sql, before_ts, 100 + ) + + updates = [] + + for redaction_id, event_id in rows: + redaction_event = await self.get_event(redaction_id, allow_none=True) + original_event = await self.get_event( + event_id, allow_rejected=True, allow_none=True + ) + + # The SQL above ensures that we have both the redaction and + # original event, so if the `get_event` calls return None it + # means that the redaction wasn't allowed. Either way we know that + # the result won't change so we mark the fact that we've checked. + if ( + redaction_event + and original_event + and original_event.internal_metadata.is_redacted() + ): + # Redaction was allowed + pruned_json = encode_json( + prune_event_dict( + original_event.room_version, original_event.get_dict() + ) + ) + else: + # Redaction wasn't allowed + pruned_json = None + + updates.append((redaction_id, event_id, pruned_json)) + + def _update_censor_txn(txn): + for redaction_id, event_id, pruned_json in updates: + if pruned_json: + self._censor_event_txn(txn, event_id, pruned_json) + + self.db.simple_update_one_txn( + txn, + table="redactions", + keyvalues={"event_id": redaction_id}, + updatevalues={"have_censored": True}, + ) + + await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.db.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json( + prune_event_dict(event.room_version, event.get_dict()) + ) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.db.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index b99439cc37..24ce8c4330 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -640,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore): self._delete_old_forward_extrem_cache, 60 * 60 * 1000 ) - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - if min_depth is not None and depth >= min_depth: - return - - self.db.simple_upsert_txn( - txn, - table="room_depth", - keyvalues={"room_id": room_id}, - values={"min_depth": depth}, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self.db.simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id in ev.prev_event_ids() - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany( - query, - [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events - for e_id in ev.prev_event_ids() - if not ev.internal_metadata.is_outlier() - ], - ) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) - for ev in events - if not ev.internal_metadata.is_outlier() - ], - ) - def _delete_old_forward_extrem_cache(self): def _delete_old_forward_extrem_cache_txn(txn): # Delete entries older than a month, while making sure we don't delete diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 8eed590929..0321274de2 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -652,69 +652,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): - """Handles moving push actions from staging table to main - event_push_actions table for all events in `events_and_contexts`. - - Also ensures that all events in `all_events_and_contexts` are removed - from the push action staging area. - - Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. - """ - - sql = """ - INSERT INTO event_push_actions ( - room_id, event_id, user_id, actions, stream_ordering, - topological_ordering, notif, highlight - ) - SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight - FROM event_push_actions_staging - WHERE event_id = ? - """ - - if events_and_contexts: - txn.executemany( - sql, - ( - ( - event.room_id, - event.internal_metadata.stream_ordering, - event.depth, - event.event_id, - ) - for event, _ in events_and_contexts - ), - ) - - for event, _ in events_and_contexts: - user_ids = self.db.simple_select_onecol_txn( - txn, - table="event_push_actions_staging", - keyvalues={"event_id": event.event_id}, - retcol="user_id", - ) - - for uid in user_ids: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid), - ) - - # Now we delete the staging area for *all* events that were being - # persisted. - txn.executemany( - "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ((event.event_id,) for event, _ in all_events_and_contexts), - ) - @defer.inlineCallbacks def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False @@ -763,17 +700,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) return result[0] or 0 - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id,), - ) - txn.execute( - "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", - (room_id, event_id), - ) - def _remove_old_push_actions_before_txn( self, txn, room_id, user_id, stream_ordering ): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index e71c23541d..092739962b 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -17,39 +17,44 @@ import itertools import logging -from collections import Counter as c_counter, OrderedDict, namedtuple +from collections import OrderedDict, namedtuple from functools import wraps -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple -from six import iteritems, text_type +from six import integer_types, iteritems, text_type from six.moves import range +import attr from canonicaljson import json from prometheus_client import Counter from twisted.internet import defer import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes -from synapse.api.errors import SynapseError +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.room_versions import RoomVersions +from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.events.utils import prune_event_dict from synapse.logging.utils import log_function -from synapse.metrics import BucketCollector -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.data_stores.main.event_federation import EventFederationStore -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.data_stores.main.search import SearchEntry from synapse.storage.database import Database, LoggingTransaction -from synapse.storage.persist_events import DeltaState -from synapse.types import RoomStreamToken, StateMap, get_domain_from_id -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.types import StateMap, get_domain_from_id from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) persist_event_counter = Counter("synapse_storage_events_persisted_events", "") @@ -94,58 +99,49 @@ def _retry_on_integrity_error(func): return f -# inherits from EventFederationStore so that we can call _update_backward_extremities -# and _handle_mult_prev_events (though arguably those could both be moved in here) -class EventsStore( - StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, -): - def __init__(self, database: Database, db_conn, hs): - super(EventsStore, self).__init__(database, db_conn, hs) +@attr.s(slots=True) +class DeltaState: + """Deltas to use to update the `current_state_events` table. - # Collect metrics on the number of forward extremities that exist. - # Counter of number of extremities to count - self._current_forward_extremities_amount = c_counter() + Attributes: + to_delete: List of type/state_keys to delete from current state + to_insert: Map of state to upsert into current state + no_longer_in_room: The server is not longer in the room, so the room + should e.g. be removed from `current_state_events` table. + """ - BucketCollector( - "synapse_forward_extremities", - lambda: self._current_forward_extremities_amount, - buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], - ) + to_delete = attr.ib(type=List[Tuple[str, str]]) + to_insert = attr.ib(type=StateMap[str]) + no_longer_in_room = attr.ib(type=bool, default=False) - # Read the extrems every 60 minutes - def read_forward_extremities(): - # run as a background process to make sure that the database transactions - # have a logcontext to report to - return run_as_background_process( - "read_forward_extremities", self._read_forward_extremities - ) - hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) +class PersistEventsStore: + """Contains all the functions for writing events to the database. - def _censor_redactions(): - return run_as_background_process( - "_censor_redactions", self._censor_redactions - ) + Should only be instansiated on one process (when using a worker mode setup). + + Note: This is not part of the `DataStore` mixin. + """ - if self.hs.config.redaction_retention_period is not None: - hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"): + self.hs = hs + self.db = db + self.store = main_data_store + self.database_engine = db.engine + self._clock = hs.get_clock() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def _read_forward_extremities(self): - def fetch(txn): - txn.execute( - """ - select count(*) c from event_forward_extremities - group by room_id - """ - ) - return txn.fetchall() + # Ideally we'd move these ID gens here, unfortunately some other ID + # generators are chained off them so doing so is a bit of a PITA. + self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator + self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator - res = yield self.db.runInteraction("read_forward_extremities", fetch) - self._current_forward_extremities_amount = c_counter([x[0] for x in res]) + # This should only exist on master for now + assert ( + hs.config.worker.worker_app is None + ), "Can only instantiate PersistEventsStore on master" @_retry_on_integrity_error @defer.inlineCallbacks @@ -237,10 +233,10 @@ class EventsStore( event_counter.labels(event.type, origin_type, origin_entity).inc() for room_id, new_state in iteritems(current_state_for_room): - self.get_current_state_ids.prefill((room_id,), new_state) + self.store.get_current_state_ids.prefill((room_id,), new_state) for room_id, latest_event_ids in iteritems(new_forward_extremeties): - self.get_latest_event_ids_in_room.prefill( + self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) ) @@ -586,7 +582,7 @@ class EventsStore( ) txn.call_after( - self._curr_state_delta_stream_cache.entity_has_changed, + self.store._curr_state_delta_stream_cache.entity_has_changed, room_id, stream_id, ) @@ -606,10 +602,13 @@ class EventsStore( for member in members_changed: txn.call_after( - self.get_rooms_for_user_with_stream_ordering.invalidate, (member,) + self.store.get_rooms_for_user_with_stream_ordering.invalidate, + (member,), ) - self._invalidate_state_caches_and_stream(txn, room_id, members_changed) + self.store._invalidate_state_caches_and_stream( + txn, room_id, members_changed + ) def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): """Update the room version in the database based off current state @@ -647,7 +646,9 @@ class EventsStore( self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) - txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + txn.call_after( + self.store.get_latest_event_ids_in_room.invalidate, (room_id,) + ) self.db.simple_insert_many_txn( txn, @@ -713,10 +714,10 @@ class EventsStore( depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids - txn.call_after(self._invalidate_get_event_cache, event.event_id) + txn.call_after(self.store._invalidate_get_event_cache, event.event_id) if not backfilled: txn.call_after( - self._events_stream_cache.entity_has_changed, + self.store._events_stream_cache.entity_has_changed, event.room_id, event.internal_metadata.stream_ordering, ) @@ -1088,13 +1089,15 @@ class EventsStore( def prefill(): for cache_entry in to_prefill: - self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry) + self.store._get_event_cache.prefill( + (cache_entry[0].event_id,), cache_entry + ) txn.call_after(prefill) def _store_redaction(self, txn, event): # invalidate the cache for the redacted event - txn.call_after(self._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store._invalidate_get_event_cache, event.redacts) self.db.simple_insert_txn( txn, @@ -1106,783 +1109,484 @@ class EventsStore( }, ) - async def _censor_redactions(self): - """Censors all redactions older than the configured period that haven't - been censored yet. + def insert_labels_for_event_txn( + self, txn, event_id, labels, room_id, topological_ordering + ): + """Store the mapping between an event's ID and its labels, with one row per + (event_id, label) tuple. - By censor we mean update the event_json table with the redacted event. + Args: + txn (LoggingTransaction): The transaction to execute. + event_id (str): The event's ID. + labels (list[str]): A list of text labels. + room_id (str): The ID of the room the event was sent to. + topological_ordering (int): The position of the event in the room's topology. """ + return self.db.simple_insert_many_txn( + txn=txn, + table="event_labels", + values=[ + { + "event_id": event_id, + "label": label, + "room_id": room_id, + "topological_ordering": topological_ordering, + } + for label in labels + ], + ) - if self.hs.config.redaction_retention_period is None: - return - - if not ( - await self.db.updates.has_completed_background_update( - "redactions_have_censored_ts_idx" - ) - ): - # We don't want to run this until the appropriate index has been - # created. - return - - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. - # We fetch all redactions that: - # 1. point to an event we have, - # 2. has a received_ts from before the cut off, and - # 3. we haven't yet censored. - # - # This is limited to 100 events to ensure that we don't try and do too - # much at once. We'll get called again so this should eventually catch - # up. - sql = """ - SELECT redactions.event_id, redacts FROM redactions - LEFT JOIN events AS original_event ON ( - redacts = original_event.event_id - ) - WHERE NOT have_censored - AND redactions.received_ts <= ? - ORDER BY redactions.received_ts ASC - LIMIT ? + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. """ - - rows = await self.db.execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 + return self.db.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, ) - updates = [] + def _store_event_reference_hashes_txn(self, txn, events): + """Store a hash for a PDU + Args: + txn (cursor): + events (list): list of Events. + """ - for redaction_id, event_id in rows: - redaction_event = await self.get_event(redaction_id, allow_none=True) - original_event = await self.get_event( - event_id, allow_rejected=True, allow_none=True + vals = [] + for event in events: + ref_alg, ref_hash_bytes = compute_event_reference_hash(event) + vals.append( + { + "event_id": event.event_id, + "algorithm": ref_alg, + "hash": memoryview(ref_hash_bytes), + } ) - # The SQL above ensures that we have both the redaction and - # original event, so if the `get_event` calls return None it - # means that the redaction wasn't allowed. Either way we know that - # the result won't change so we mark the fact that we've checked. - if ( - redaction_event - and original_event - and original_event.internal_metadata.is_redacted() - ): - # Redaction was allowed - pruned_json = encode_json( - prune_event_dict( - original_event.room_version, original_event.get_dict() - ) - ) - else: - # Redaction wasn't allowed - pruned_json = None - - updates.append((redaction_id, event_id, pruned_json)) - - def _update_censor_txn(txn): - for redaction_id, event_id, pruned_json in updates: - if pruned_json: - self._censor_event_txn(txn, event_id, pruned_json) - - self.db.simple_update_one_txn( - txn, - table="redactions", - keyvalues={"event_id": redaction_id}, - updatevalues={"have_censored": True}, - ) - - await self.db.runInteraction("_update_censor_txn", _update_censor_txn) + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) - def _censor_event_txn(self, txn, event_id, pruned_json): - """Censor an event by replacing its JSON in the event_json table with the - provided pruned JSON. - - Args: - txn (LoggingTransaction): The database transaction. - event_id (str): The ID of the event to censor. - pruned_json (str): The pruned JSON + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. """ - self.db.simple_update_one_txn( + self.db.simple_insert_many_txn( txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ], ) - @defer.inlineCallbacks - def count_daily_messages(self): - """ - Returns an estimate of the number of messages sent in the last day. - - If it has been significantly less or more than one day since the last - call to this function, it will return None. - """ - - def _count_messages(txn): - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_sent_messages(self): - def _count_messages(txn): - # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. - like_clause = "%:" + self.hs.hostname - - sql = """ - SELECT COALESCE(COUNT(*), 0) FROM events - WHERE type = 'm.room.message' - AND sender LIKE ? - AND stream_ordering > ? - """ - - txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) - return ret - - @defer.inlineCallbacks - def count_daily_active_rooms(self): - def _count(txn): - sql = """ - SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events - WHERE type = 'm.room.message' - AND stream_ordering > ? - """ - txn.execute(sql, (self.stream_ordering_day_ago,)) - (count,) = txn.fetchone() - return count - - ret = yield self.db.runInteraction("count_daily_active_rooms", _count) - return ret - - @cached(num_args=5, max_entries=10) - def get_all_new_events( - self, - last_backfill_id, - last_forward_id, - current_backfill_id, - current_forward_id, - limit, - ): - """Get all the new events that have arrived at the server either as - new events or as backfilled events""" - have_backfill_events = last_backfill_id != current_backfill_id - have_forward_events = last_forward_id != current_forward_id - - if not have_backfill_events and not have_forward_events: - return defer.succeed(AllNewEventsResult([], [], [], [], [])) - - def get_all_new_events_txn(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" + for event in events: + txn.call_after( + self.store._membership_stream_cache.entity_has_changed, + event.state_key, + event.internal_metadata.stream_ordering, ) - if have_forward_events: - txn.execute(sql, (last_forward_id, current_forward_id, limit)) - new_forward_events = txn.fetchall() - - if len(new_forward_events) == limit: - upper_bound = new_forward_events[-1][0] - else: - upper_bound = current_forward_id - - sql = ( - "SELECT event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - forward_ex_outliers = txn.fetchall() - else: - new_forward_events = [] - forward_ex_outliers = [] - - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering DESC" - " LIMIT ?" + txn.call_after( + self.store.get_invited_rooms_for_local_user.invalidate, + (event.state_key,), ) - if have_backfill_events: - txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) - new_backfill_events = txn.fetchall() - if len(new_backfill_events) == limit: - upper_bound = new_backfill_events[-1][0] - else: - upper_bound = current_backfill_id - - sql = ( - "SELECT -event_stream_ordering, event_id, state_group" - " FROM ex_outlier_stream" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_backfill_id, -upper_bound)) - backward_ex_outliers = txn.fetchall() - else: - new_backfill_events = [] - backward_ex_outliers = [] - - return AllNewEventsResult( - new_forward_events, - new_backfill_events, - forward_ex_outliers, - backward_ex_outliers, + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. If the event is an + # outlier it is only current if its an "out of band membership", + # like a remote invite or a rejection of a remote invite. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_out_of_band_membership() ) + is_mine = self.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self.db.simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + }, + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) - return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) - - def purge_history(self, room_id, token, delete_local_events): - """Deletes room history before a certain point - - Args: - room_id (str): + txn.execute( + sql, + ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + ), + ) - token (str): A topological token to delete events before + # We also update the `local_current_membership` table with + # latest invite info. This will usually get updated by the + # `current_state_events` handling, unless its an outlier. + if event.internal_metadata.is_outlier(): + # This should only happen for out of band memberships, so + # we add a paranoia check. + assert event.internal_metadata.is_out_of_band_membership() + + self.db.simple_upsert_txn( + txn, + table="local_current_membership", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values={ + "event_id": event.event_id, + "membership": event.membership, + }, + ) - delete_local_events (bool): - if True, we will delete local events as well as remote ones - (instead of just marking them as outliers and deleting their - state groups). + def _handle_event_relations(self, txn, event): + """Handles inserting relation data during peristence of events - Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + Args: + txn + event (EventBase) """ + relation = event.content.get("m.relates_to") + if not relation: + # No relations + return - return self.db.runInteraction( - "purge_history", - self._purge_history_txn, - room_id, - token, - delete_local_events, - ) + rel_type = relation.get("rel_type") + if rel_type not in ( + RelationTypes.ANNOTATION, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, + ): + # Unknown relation type + return - def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): - token = RoomStreamToken.parse(token_str) - - # Tables that should be pruned: - # event_auth - # event_backward_extremities - # event_edges - # event_forward_extremities - # event_json - # event_push_actions - # event_reference_hashes - # event_search - # event_to_state_groups - # events - # rejections - # room_depth - # state_groups - # state_groups_state - - # we will build a temporary table listing the events so that we don't - # have to keep shovelling the list back and forth across the - # connection. Annoyingly the python sqlite driver commits the - # transaction on CREATE, so let's do this first. - # - # furthermore, we might already have the table from a previous (failed) - # purge attempt, so let's drop the table first. + parent_id = relation.get("event_id") + if not parent_id: + # Invalid relation + return - txn.execute("DROP TABLE IF EXISTS events_to_purge") + aggregation_key = relation.get("key") - txn.execute( - "CREATE TEMPORARY TABLE events_to_purge (" - " event_id TEXT NOT NULL," - " should_delete BOOLEAN NOT NULL" - ")" + self.db.simple_insert_txn( + txn, + table="event_relations", + values={ + "event_id": event.event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + "aggregation_key": aggregation_key, + }, ) - # First ensure that we're not about to delete all the forward extremeties - txn.execute( - "SELECT e.event_id, e.depth FROM events as e " - "INNER JOIN event_forward_extremities as f " - "ON e.event_id = f.event_id " - "AND e.room_id = f.room_id " - "WHERE f.room_id = ?", - (room_id,), + txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - rows = txn.fetchall() - max_depth = max(row[1] for row in rows) - - if max_depth < token.topological: - # We need to ensure we don't delete all the events from the database - # otherwise we wouldn't be able to send any events (due to not - # having any backwards extremeties) - raise SynapseError( - 400, "topological_ordering is greater than forward extremeties" - ) - - logger.info("[purge] looking for events to delete") - - should_delete_expr = "state_key IS NULL" - should_delete_params = () - if not delete_local_events: - should_delete_expr += " AND event_id NOT LIKE ?" - - # We include the parameter twice since we use the expression twice - should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) - should_delete_params += (room_id, token.topological) + if rel_type == RelationTypes.REPLACE: + txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - # Note that we insert events that are outliers and aren't going to be - # deleted, as nothing will happen to them. - txn.execute( - "INSERT INTO events_to_purge" - " SELECT event_id, %s" - " FROM events AS e LEFT JOIN state_events USING (event_id)" - " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" - % (should_delete_expr, should_delete_expr), - should_delete_params, - ) + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. - # We create the indices *after* insertion as that's a lot faster. + Args: + txn + redacted_event_id (str): The event that was redacted. + """ - # create an index on should_delete because later we'll be looking for - # the should_delete / shouldn't_delete subsets - txn.execute( - "CREATE INDEX events_to_purge_should_delete" - " ON events_to_purge(should_delete)" + self.db.simple_delete_txn( + txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - # We do joins against events_to_purge for e.g. calculating state - # groups to purge, etc., so lets make an index. - txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") - - txn.execute("SELECT event_id, should_delete FROM events_to_purge") - event_rows = txn.fetchall() - logger.info( - "[purge] found %i events before cutoff, of which %i can be deleted", - len(event_rows), - sum(1 for e in event_rows if e[1]), - ) + def _store_room_topic_txn(self, txn, event): + if hasattr(event, "content") and "topic" in event.content: + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"] + ) - logger.info("[purge] Finding new backward extremities") + def _store_room_name_txn(self, txn, event): + if hasattr(event, "content") and "name" in event.content: + self.store_event_search_txn( + txn, event, "content.name", event.content["name"] + ) - # We calculate the new entries for the backward extremeties by finding - # events to be purged that are pointed to by events we're not going to - # purge. - txn.execute( - "SELECT DISTINCT e.event_id FROM events_to_purge AS e" - " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" - " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" - " WHERE ep2.event_id IS NULL" - ) - new_backwards_extrems = txn.fetchall() + def _store_room_message_txn(self, txn, event): + if hasattr(event, "content") and "body" in event.content: + self.store_event_search_txn( + txn, event, "content.body", event.content["body"] + ) - logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + def _store_retention_policy_for_room_txn(self, txn, event): + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content + ): + if ( + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) + ): + # Ignore the event if one of the value isn't an integer. + return - txn.execute( - "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) - ) + self.db.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) - # Update backward extremeties - txn.executemany( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], - ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_retention_policy_for_room, (event.room_id,) + ) - logger.info("[purge] finding state groups referenced by deleted events") + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table - # Get all state groups that are referenced by events that are to be - # deleted. - txn.execute( - """ - SELECT DISTINCT state_group FROM events_to_purge - INNER JOIN event_to_state_groups USING (event_id) + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): """ + self.store.store_search_entries_txn( + txn, + ( + SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ), + ), ) - referenced_state_groups = {sg for sg, in txn} - logger.info( - "[purge] found %i referenced state groups", len(referenced_state_groups) - ) + def _set_push_actions_for_event_and_users_txn( + self, txn, events_and_contexts, all_events_and_contexts + ): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. - logger.info("[purge] removing events from event_to_state_groups") - txn.execute( - "DELETE FROM event_to_state_groups " - "WHERE event_id IN (SELECT event_id from events_to_purge)" - ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. - # Delete all remote non-state events - for table in ( - "events", - "event_json", - "event_auth", - "event_edges", - "event_forward_extremities", - "event_reference_hashes", - "event_search", - "rejections", - ): - logger.info("[purge] removing events from %s", table) + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ - txn.execute( - "DELETE FROM %s WHERE event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,) + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ - # event_push_actions lacks an index on event_id, and has one on - # (room_id, event_id) instead. - for table in ("event_push_actions",): - logger.info("[purge] removing events from %s", table) + if events_and_contexts: + txn.executemany( + sql, + ( + ( + event.room_id, + event.internal_metadata.stream_ordering, + event.depth, + event.event_id, + ) + for event, _ in events_and_contexts + ), + ) - txn.execute( - "DELETE FROM %s WHERE room_id = ? AND event_id IN (" - " SELECT event_id FROM events_to_purge WHERE should_delete" - ")" % (table,), - (room_id,), + for event, _ in events_and_contexts: + user_ids = self.db.simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcol="user_id", ) - # Mark all state and own events as outliers - logger.info("[purge] marking remaining events as outliers") - txn.execute( - "UPDATE events SET outlier = ?" - " WHERE event_id IN (" - " SELECT event_id FROM events_to_purge " - " WHERE NOT should_delete" - ")", - (True,), + for uid in user_ids: + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid), + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ((event.event_id,) for event, _ in all_events_and_contexts), ) - # synapse tries to take out an exclusive lock on room_depth whenever it - # persists events (because upsert), and once we run this update, we - # will block that for the rest of our transaction. - # - # So, let's stick it at the end so that we don't block event - # persistence. - # - # We do this by calculating the minimum depth of the backwards - # extremities. However, the events in event_backward_extremities - # are ones we don't have yet so we need to look at the events that - # point to it via event_edges table. - txn.execute( - """ - SELECT COALESCE(MIN(depth), 0) - FROM event_backward_extremities AS eb - INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id - INNER JOIN events AS e ON e.event_id = eg.event_id - WHERE eb.room_id = ? - """, + def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + # Sad that we have to blow away the cache for the whole room here + txn.call_after( + self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, (room_id,), ) - (min_depth,) = txn.fetchone() - - logger.info("[purge] updating room_depth to %d", min_depth) - txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (min_depth, room_id), + "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", + (room_id, event_id), ) - # finally, drop the temp table. this will commit the txn in sqlite, - # so make sure to keep this actually last. - txn.execute("DROP TABLE events_to_purge") - - logger.info("[purge] done") - - return referenced_state_groups - - def purge_room(self, room_id): - """Deletes all record of a room + def _store_rejections_txn(self, txn, event_id, reason): + self.db.simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + }, + ) - Args: - room_id (str) + def _store_event_state_mappings_txn( + self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] + ): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue - Returns: - Deferred[List[int]]: The list of state groups to delete. - """ + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.state_group_before_event + continue - return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + state_groups[event.event_id] = context.state_group - def _purge_room_txn(self, txn, room_id): - # First we fetch all the state groups that should be deleted, before - # we delete that information. - txn.execute( - """ - SELECT DISTINCT state_group FROM events - INNER JOIN event_to_state_groups USING(event_id) - WHERE events.room_id = ? - """, - (room_id,), + self.db.simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + {"state_group": state_group_id, "event_id": event_id} + for event_id, state_group_id in iteritems(state_groups) + ], ) - state_groups = [row[0] for row in txn] - - # Now we delete tables which lack an index on room_id but have one on event_id - for table in ( - "event_auth", - "event_edges", - "event_push_actions_staging", - "event_reference_hashes", - "event_relations", - "event_to_state_groups", - "redactions", - "rejections", - "state_events", - ): - logger.info("[purge] removing %s from %s", room_id, table) - - txn.execute( - """ - DELETE FROM %s WHERE event_id IN ( - SELECT event_id FROM events WHERE room_id=? - ) - """ - % (table,), - (room_id,), + for event_id, state_group_id in iteritems(state_groups): + txn.call_after( + self.store._get_state_group_for_event.prefill, + (event_id,), + state_group_id, ) - # and finally, the tables with an index on room_id (or no useful index) - for table in ( - "current_state_events", - "event_backward_extremities", - "event_forward_extremities", - "event_json", - "event_push_actions", - "event_search", - "events", - "group_rooms", - "public_room_list_stream", - "receipts_graph", - "receipts_linearized", - "room_aliases", - "room_depth", - "room_memberships", - "room_stats_state", - "room_stats_current", - "room_stats_historical", - "room_stats_earliest_token", - "rooms", - "stream_ordering_to_exterm", - "users_in_public_rooms", - "users_who_share_private_rooms", - # no useful index, but let's clear them anyway - "appservice_room_list", - "e2e_room_keys", - "event_push_summary", - "pusher_throttle", - "group_summary_rooms", - "local_invites", - "room_account_data", - "room_tags", - "local_current_membership", - ): - logger.info("[purge] removing %s from %s", room_id, table) - txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) - - # Other tables we do NOT need to clear out: - # - # - blocked_rooms - # This is important, to make sure that we don't accidentally rejoin a blocked - # room after it was purged - # - # - user_directory - # This has a room_id column, but it is unused - # - - # Other tables that we might want to consider clearing out include: - # - # - event_reports - # Given that these are intended for abuse management my initial - # inclination is to leave them in place. - # - # - current_state_delta_stream - # - ex_outlier_stream - # - room_tags_revisions - # The problem with these is that they are largeish and there is no room_id - # index on them. In any case we should be clearing out 'stream' tables - # periodically anyway (#5888) - - # TODO: we could probably usefully do a bunch of cache invalidation here + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self.store._get_min_depth_interaction(txn, room_id) - logger.info("[purge] done") - - return state_groups - - async def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ - to_1, so_1 = await self._get_event_ordering(event_id1) - to_2, so_2 = await self._get_event_ordering(event_id2) - return (to_1, so_1) > (to_2, so_2) + if min_depth is not None and depth >= min_depth: + return - @cachedInlineCallbacks(max_entries=5000) - def _get_event_ordering(self, event_id): - res = yield self.db.simple_select_one( - table="events", - retcols=["topological_ordering", "stream_ordering"], - keyvalues={"event_id": event_id}, - allow_none=True, + self.db.simple_upsert_txn( + txn, + table="room_depth", + keyvalues={"room_id": room_id}, + values={"min_depth": depth}, ) - if not res: - raise SynapseError(404, "Could not find event %s" % (event_id,)) - - return (int(res["topological_ordering"]), int(res["stream_ordering"])) - - def insert_labels_for_event_txn( - self, txn, event_id, labels, room_id, topological_ordering - ): - """Store the mapping between an event's ID and its labels, with one row per - (event_id, label) tuple. - - Args: - txn (LoggingTransaction): The transaction to execute. - event_id (str): The event's ID. - labels (list[str]): A list of text labels. - room_id (str): The ID of the room the event was sent to. - topological_ordering (int): The position of the event in the room's topology. + def _handle_mult_prev_events(self, txn, events): """ - return self.db.simple_insert_many_txn( - txn=txn, - table="event_labels", + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self.db.simple_insert_many_txn( + txn, + table="event_edges", values=[ { - "event_id": event_id, - "label": label, - "room_id": room_id, - "topological_ordering": topological_ordering, + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, } - for label in labels + for ev in events + for e_id in ev.prev_event_ids() ], ) - def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): - """Save the expiry timestamp associated with a given event ID. - - Args: - txn (LoggingTransaction): The database transaction to use. - event_id (str): The event ID the expiry timestamp is associated with. - expiry_ts (int): The timestamp at which to expire (delete) the event. - """ - return self.db.simple_insert_txn( - txn=txn, - table="event_expiry", - values={"event_id": event_id, "expiry_ts": expiry_ts}, - ) - - @defer.inlineCallbacks - def expire_event(self, event_id): - """Retrieve and expire an event that has expired, and delete its associated - expiry timestamp. If the event can't be retrieved, delete its associated - timestamp so we don't try to expire it again in the future. - - Args: - event_id (str): The ID of the event to delete. - """ - # Try to retrieve the event's content from the database or the event cache. - event = yield self.get_event(event_id) - - def delete_expired_event_txn(txn): - # Delete the expiry timestamp associated with this event from the database. - self._delete_event_expiry_txn(txn, event_id) - - if not event: - # If we can't find the event, log a warning and delete the expiry date - # from the database so that we don't try to expire it again in the - # future. - logger.warning( - "Can't expire event %s because we don't have it.", event_id - ) - return - - # Prune the event's dict then convert it to JSON. - pruned_json = encode_json( - prune_event_dict(event.room_version, event.get_dict()) - ) - - # Update the event_json table to replace the event's JSON with the pruned - # JSON. - self._censor_event_txn(txn, event.event_id, pruned_json) - - # We need to invalidate the event cache entry for this event because we - # changed its content in the database. We can't call - # self._invalidate_cache_and_stream because self.get_event_cache isn't of the - # right type. - txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) - # Send that invalidation to replication so that other workers also invalidate - # the event cache. - self._send_invalidation_to_replication( - txn, "_get_event_cache", (event.event_id,) - ) + self._update_backward_extremeties(txn, events) - yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. - def _delete_event_expiry_txn(self, txn, event_id): - """Delete the expiry timestamp associated with an event ID without deleting the - actual event. + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. - Args: - txn (LoggingTransaction): The transaction to use to perform the deletion. - event_id (str): The event ID to delete the associated expiry timestamp of. + Forward extremities are handled when we first start persisting the events. """ - return self.db.simple_delete_txn( - txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" ) - def get_next_event_to_expire(self): - """Retrieve the entry with the lowest expiry timestamp in the event_expiry - table, or None if there's no more event to expire. - - Returns: Deferred[Optional[Tuple[str, int]]] - A tuple containing the event ID as its first element and an expiry timestamp - as its second one, if there's at least one row in the event_expiry table. - None otherwise. - """ - - def get_next_event_to_expire_txn(txn): - txn.execute( - """ - SELECT event_id, expiry_ts FROM event_expiry - ORDER BY expiry_ts ASC LIMIT 1 - """ - ) - - return txn.fetchone() - - return self.db.runInteraction( - desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + txn.executemany( + query, + [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events + for e_id in ev.prev_event_ids() + if not ev.internal_metadata.is_outlier() + ], ) - -AllNewEventsResult = namedtuple( - "AllNewEventsResult", - [ - "new_forward_events", - "new_backfill_events", - "forward_ex_outliers", - "backward_ex_outliers", - ], -) + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) + for ev in events + if not ev.internal_metadata.is_outlier() + ], + ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index b8c1bbdf99..970c31bd05 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -27,7 +27,7 @@ from constantly import NamedConstant, Names from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.api.errors import NotFoundError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.room_versions import ( KNOWN_ROOM_VERSIONS, EventFormatVersions, @@ -40,7 +40,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1157,4 +1157,152 @@ class EventsWorkerStore(SQLBaseStore): rows = await self.db.runInteraction( "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token ) + return rows, to_token, True + + @cached(num_args=5, max_entries=10) + def get_all_new_events( + self, + last_backfill_id, + last_forward_id, + current_backfill_id, + current_forward_id, + limit, + ): + """Get all the new events that have arrived at the server either as + new events or as backfilled events""" + have_backfill_events = last_backfill_id != current_backfill_id + have_forward_events = last_forward_id != current_forward_id + + if not have_backfill_events and not have_forward_events: + return defer.succeed(AllNewEventsResult([], [], [], [], [])) + + def get_all_new_events_txn(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + if have_forward_events: + txn.execute(sql, (last_forward_id, current_forward_id, limit)) + new_forward_events = txn.fetchall() + + if len(new_forward_events) == limit: + upper_bound = new_forward_events[-1][0] + else: + upper_bound = current_forward_id + + sql = ( + "SELECT event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_forward_id, upper_bound)) + forward_ex_outliers = txn.fetchall() + else: + new_forward_events = [] + forward_ex_outliers = [] + + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + if have_backfill_events: + txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit)) + new_backfill_events = txn.fetchall() + + if len(new_backfill_events) == limit: + upper_bound = new_backfill_events[-1][0] + else: + upper_bound = current_backfill_id + + sql = ( + "SELECT -event_stream_ordering, event_id, state_group" + " FROM ex_outlier_stream" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_backfill_id, -upper_bound)) + backward_ex_outliers = txn.fetchall() + else: + new_backfill_events = [] + backward_ex_outliers = [] + + return AllNewEventsResult( + new_forward_events, + new_backfill_events, + forward_ex_outliers, + backward_ex_outliers, + ) + + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) + + async def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = await self._get_event_ordering(event_id1) + to_2, so_2 = await self._get_event_ordering(event_id2) + return (to_1, so_1) > (to_2, so_2) + + @cachedInlineCallbacks(max_entries=5000) + def _get_event_ordering(self, event_id): + res = yield self.db.simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True, + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + return (int(res["topological_ordering"]), int(res["stream_ordering"])) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + + +AllNewEventsResult = namedtuple( + "AllNewEventsResult", + [ + "new_forward_events", + "new_backfill_events", + "forward_ex_outliers", + "backward_ex_outliers", + ], +) diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py new file mode 100644 index 0000000000..dad5bbc602 --- /dev/null +++ b/synapse/storage/data_stores/main/metrics.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from collections import Counter + +from twisted.internet import defer + +from synapse.metrics import BucketCollector +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.event_push_actions import ( + EventPushActionsWorkerStore, +) +from synapse.storage.database import Database + + +class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): + """Functions to pull various metrics from the DB, for e.g. phone home + stats and prometheus metrics. + """ + + def __init__(self, database: Database, db_conn, hs): + super().__init__(database, db_conn, hs) + + # Collect metrics on the number of forward extremities that exist. + # Counter of number of extremities to count + self._current_forward_extremities_amount = ( + Counter() + ) # type: typing.Counter[int] + + BucketCollector( + "synapse_forward_extremities", + lambda: self._current_forward_extremities_amount, + buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"], + ) + + # Read the extrems every 60 minutes + def read_forward_extremities(): + # run as a background process to make sure that the database transactions + # have a logcontext to report to + return run_as_background_process( + "read_forward_extremities", self._read_forward_extremities + ) + + hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000) + + async def _read_forward_extremities(self): + def fetch(txn): + txn.execute( + """ + select count(*) c from event_forward_extremities + group by room_id + """ + ) + return txn.fetchall() + + res = await self.db.runInteraction("read_forward_extremities", fetch) + self._current_forward_extremities_amount = Counter([x[0] for x in res]) + + @defer.inlineCallbacks + def count_daily_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) + return ret + + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) + return ret diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py new file mode 100644 index 0000000000..a93e1ef198 --- /dev/null +++ b/synapse/storage/data_stores/main/purge_events.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Tuple + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.types import RoomStreamToken + +logger = logging.getLogger(__name__) + + +class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): + def purge_history(self, room_id, token, delete_local_events): + """Deletes room history before a certain point + + Args: + room_id (str): + + token (str): A topological token to delete events before + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). + + Returns: + Deferred[set[int]]: The set of state groups that are referenced by + deleted events. + """ + + return self.db.runInteraction( + "purge_history", + self._purge_history_txn, + room_id, + token, + delete_local_events, + ) + + def _purge_history_txn(self, txn, room_id, token_str, delete_local_events): + token = RoomStreamToken.parse(token_str) + + # Tables that should be pruned: + # event_auth + # event_backward_extremities + # event_edges + # event_forward_extremities + # event_json + # event_push_actions + # event_reference_hashes + # event_search + # event_to_state_groups + # events + # rejections + # room_depth + # state_groups + # state_groups_state + + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # First ensure that we're not about to delete all the forward extremeties + txn.execute( + "SELECT e.event_id, e.depth FROM events as e " + "INNER JOIN event_forward_extremities as f " + "ON e.event_id = f.event_id " + "AND e.room_id = f.room_id " + "WHERE f.room_id = ?", + (room_id,), + ) + rows = txn.fetchall() + max_depth = max(row[1] for row in rows) + + if max_depth < token.topological: + # We need to ensure we don't delete all the events from the database + # otherwise we wouldn't be able to send any events (due to not + # having any backwards extremeties) + raise SynapseError( + 400, "topological_ordering is greater than forward extremeties" + ) + + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () # type: Tuple[Any, ...] + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + + # We include the parameter twice since we use the expression twice + should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname) + + should_delete_params += (room_id, token.topological) + + # Note that we insert events that are outliers and aren't going to be + # deleted, as nothing will happen to them. + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?" + % (should_delete_expr, should_delete_expr), + should_delete_params, + ) + + # We create the indices *after* insertion as that's a lot faster. + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)" + ) + + # We do joins against events_to_purge for e.g. calculating state + # groups to purge, etc., so lets make an index. + txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") + + txn.execute("SELECT event_id, should_delete FROM events_to_purge") + event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), + sum(1 for e in event_rows if e[1]), + ) + + logger.info("[purge] Finding new backward extremities") + + # We calculate the new entries for the backward extremeties by finding + # events to be purged that are pointed to by events we're not going to + # purge. + txn.execute( + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id" + " WHERE ep2.event_id IS NULL" + ) + new_backwards_extrems = txn.fetchall() + + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + + txn.execute( + "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) + ) + + # Update backward extremeties + txn.executemany( + "INSERT INTO event_backward_extremities (room_id, event_id)" + " VALUES (?, ?)", + [(room_id, event_id) for event_id, in new_backwards_extrems], + ) + + logger.info("[purge] finding state groups referenced by deleted events") + + # Get all state groups that are referenced by events that are to be + # deleted. + txn.execute( + """ + SELECT DISTINCT state_group FROM events_to_purge + INNER JOIN event_to_state_groups USING (event_id) + """ + ) + + referenced_state_groups = {sg for sg, in txn} + logger.info( + "[purge] found %i referenced state groups", len(referenced_state_groups) + ) + + logger.info("[purge] removing events from event_to_state_groups") + txn.execute( + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" + ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + + # Delete all remote non-state events + for table in ( + "events", + "event_json", + "event_auth", + "event_edges", + "event_forward_extremities", + "event_reference_hashes", + "event_search", + "rejections", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,) + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ("event_push_actions",): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id,), + ) + + # Mark all state and own events as outliers + logger.info("[purge] marking remaining events as outliers") + txn.execute( + "UPDATE events SET outlier = ?" + " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + # + # We do this by calculating the minimum depth of the backwards + # extremities. However, the events in event_backward_extremities + # are ones we don't have yet so we need to look at the events that + # point to it via event_edges table. + txn.execute( + """ + SELECT COALESCE(MIN(depth), 0) + FROM event_backward_extremities AS eb + INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id + INNER JOIN events AS e ON e.event_id = eg.event_id + WHERE eb.room_id = ? + """, + (room_id,), + ) + (min_depth,) = txn.fetchone() + + logger.info("[purge] updating room_depth to %d", min_depth) + + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (min_depth, room_id), + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute("DROP TABLE events_to_purge") + + logger.info("[purge] done") + + return referenced_state_groups + + def purge_room(self, room_id): + """Deletes all record of a room + + Args: + room_id (str) + + Returns: + Deferred[List[int]]: The list of state groups to delete. + """ + + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) + + def _purge_room_txn(self, txn, room_id): + # First we fetch all the state groups that should be deleted, before + # we delete that information. + txn.execute( + """ + SELECT DISTINCT state_group FROM events + INNER JOIN event_to_state_groups USING(event_id) + WHERE events.room_id = ? + """, + (room_id,), + ) + + state_groups = [row[0] for row in txn] + + # Now we delete tables which lack an index on room_id but have one on event_id + for table in ( + "event_auth", + "event_edges", + "event_push_actions_staging", + "event_reference_hashes", + "event_relations", + "event_to_state_groups", + "redactions", + "rejections", + "state_events", + ): + logger.info("[purge] removing %s from %s", room_id, table) + + txn.execute( + """ + DELETE FROM %s WHERE event_id IN ( + SELECT event_id FROM events WHERE room_id=? + ) + """ + % (table,), + (room_id,), + ) + + # and finally, the tables with an index on room_id (or no useful index) + for table in ( + "current_state_events", + "event_backward_extremities", + "event_forward_extremities", + "event_json", + "event_push_actions", + "event_search", + "events", + "group_rooms", + "public_room_list_stream", + "receipts_graph", + "receipts_linearized", + "room_aliases", + "room_depth", + "room_memberships", + "room_stats_state", + "room_stats_current", + "room_stats_historical", + "room_stats_earliest_token", + "rooms", + "stream_ordering_to_exterm", + "users_in_public_rooms", + "users_who_share_private_rooms", + # no useful index, but let's clear them anyway + "appservice_room_list", + "e2e_room_keys", + "event_push_summary", + "pusher_throttle", + "group_summary_rooms", + "local_invites", + "room_account_data", + "room_tags", + "local_current_membership", + ): + logger.info("[purge] removing %s from %s", room_id, table) + txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) + + # Other tables we do NOT need to clear out: + # + # - blocked_rooms + # This is important, to make sure that we don't accidentally rejoin a blocked + # room after it was purged + # + # - user_directory + # This has a room_id column, but it is unused + # + + # Other tables that we might want to consider clearing out include: + # + # - event_reports + # Given that these are intended for abuse management my initial + # inclination is to leave them in place. + # + # - current_state_delta_stream + # - ex_outlier_stream + # - room_tags_revisions + # The problem with these is that they are largeish and there is no room_id + # index on them. In any case we should be clearing out 'stream' tables + # periodically anyway (#5888) + + # TODO: we could probably usefully do a bunch of cache invalidation here + + logger.info("[purge] done") + + return state_groups diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 1c07c7a425..27e5a2084a 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -21,17 +21,6 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): - def _store_rejections_txn(self, txn, event_id, reason): - self.db.simple_insert_txn( - txn, - table="rejections", - values={ - "event_id": event_id, - "reason": reason, - "last_check": self._clock.time_msec(), - }, - ) - def get_rejection_reason(self, event_id): return self.db.simple_select_one_onecol( table="rejections", diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 046c2b4845..7d477f8d01 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -324,62 +324,4 @@ class RelationsWorkerStore(SQLBaseStore): class RelationsStore(RelationsWorkerStore): - def _handle_event_relations(self, txn, event): - """Handles inserting relation data during peristence of events - - Args: - txn - event (EventBase) - """ - relation = event.content.get("m.relates_to") - if not relation: - # No relations - return - - rel_type = relation.get("rel_type") - if rel_type not in ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.REPLACE, - ): - # Unknown relation type - return - - parent_id = relation.get("event_id") - if not parent_id: - # Invalid relation - return - - aggregation_key = relation.get("key") - - self.db.simple_insert_txn( - txn, - table="event_relations", - values={ - "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, - }, - ) - - txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,)) - txn.call_after( - self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) - ) - - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) - - def _handle_redaction(self, txn, redacted_event_id): - """Handles receiving a redaction and checking whether we need to remove - any redacted relations from the database. - - Args: - txn - redacted_event_id (str): The event that was redacted. - """ - - self.db.simple_delete_txn( - txn, table="event_relations", keyvalues={"event_id": redacted_event_id} - ) + pass diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index cafa664c16..46f643c6b9 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -21,8 +21,6 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from six import integer_types - from canonicaljson import json from twisted.internet import defer @@ -1302,53 +1300,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return self.db.runInteraction("get_rooms", f) - def _store_room_topic_txn(self, txn, event): - if hasattr(event, "content") and "topic" in event.content: - self.store_event_search_txn( - txn, event, "content.topic", event.content["topic"] - ) - - def _store_room_name_txn(self, txn, event): - if hasattr(event, "content") and "name" in event.content: - self.store_event_search_txn( - txn, event, "content.name", event.content["name"] - ) - - def _store_room_message_txn(self, txn, event): - if hasattr(event, "content") and "body" in event.content: - self.store_event_search_txn( - txn, event, "content.body", event.content["body"] - ) - - def _store_retention_policy_for_room_txn(self, txn, event): - if hasattr(event, "content") and ( - "min_lifetime" in event.content or "max_lifetime" in event.content - ): - if ( - "min_lifetime" in event.content - and not isinstance(event.content.get("min_lifetime"), integer_types) - ) or ( - "max_lifetime" in event.content - and not isinstance(event.content.get("max_lifetime"), integer_types) - ): - # Ignore the event if one of the value isn't an integer. - return - - self.db.simple_insert_txn( - txn=txn, - table="room_retention", - values={ - "room_id": event.room_id, - "event_id": event.event_id, - "min_lifetime": event.content.get("min_lifetime"), - "max_lifetime": event.content.get("max_lifetime"), - }, - ) - - self._invalidate_cache_and_stream( - txn, self.get_retention_policy_for_room, (event.room_id,) - ) - def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index ae9a76713c..04ffdcf934 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -1050,96 +1050,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(RoomMemberStore, self).__init__(database, db_conn, hs) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self.db.simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ], - ) - - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.get_invited_rooms_for_local_user.invalidate, (event.state_key,) - ) - - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. If the event is an - # outlier it is only current if its an "out of band membership", - # like a remote invite or a rejection of a remote invite. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_out_of_band_membership() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self.db.simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - }, - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) - - txn.execute( - sql, - ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - ), - ) - - # We also update the `local_current_membership` table with - # latest invite info. This will usually get updated by the - # `current_state_events` handling, unless its an outlier. - if event.internal_metadata.is_outlier(): - # This should only happen for out of band memberships, so - # we add a paranoia check. - assert event.internal_metadata.is_out_of_band_membership() - - self.db.simple_upsert_txn( - txn, - table="local_current_membership", - keyvalues={ - "room_id": event.room_id, - "user_id": event.state_key, - }, - values={ - "event_id": event.event_id, - "membership": event.membership, - }, - ) - @defer.inlineCallbacks def locally_reject_invite(self, user_id, room_id): sql = ( diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 47ebb8a214..ee75b92344 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -347,29 +347,6 @@ class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(SearchStore, self).__init__(database, db_conn, hs) - def store_event_search_txn(self, txn, event, key, value): - """Add event to the search table - - Args: - txn (cursor): - event (EventBase): - key (str): - value (str): - """ - self.store_search_entries_txn( - txn, - ( - SearchEntry( - key=key, - value=value, - event_id=event.event_id, - room_id=event.room_id, - stream_ordering=event.internal_metadata.stream_ordering, - origin_server_ts=event.origin_server_ts, - ), - ), - ) - @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 563216b63c..36244d9f5d 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -13,23 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import six - from unpaddedbase64 import encode_base64 from twisted.internet import defer -from synapse.crypto.event_signing import compute_event_reference_hash from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedList -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class SignatureWorkerStore(SQLBaseStore): @cached() @@ -79,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore): class SignatureStore(SignatureWorkerStore): """Persistence for event signatures and hashes""" - - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. - """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append( - { - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": db_binary_type(ref_hash_bytes), - } - ) - - self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 3a3b9a8e72..21052fcc7a 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -16,17 +16,12 @@ import collections.abc import logging from collections import namedtuple -from typing import Iterable, Tuple - -from six import iteritems from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion -from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore @@ -473,33 +468,3 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): super(StateStore, self).__init__(database, db_conn, hs) - - def _store_event_state_mappings_txn( - self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] - ): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.state_group_before_event - continue - - state_groups[event.event_id] = context.state_group - - self.db.simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in iteritems(state_groups) - ], - ) - - for event_id, state_group_id in iteritems(state_groups): - txn.call_after( - self._get_state_group_for_event.prefill, (event_id,), state_group_id - ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 0f9ac1cf09..41881ea20b 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -23,7 +23,6 @@ from typing import Iterable, List, Optional, Set, Tuple from six import iteritems from six.moves import range -import attr from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -35,6 +34,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateResolutionStore from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main.events import DeltaState from synapse.types import StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -73,22 +73,6 @@ stale_forward_extremities_counter = Histogram( ) -@attr.s(slots=True) -class DeltaState: - """Deltas to use to update the `current_state_events` table. - - Attributes: - to_delete: List of type/state_keys to delete from current state - to_insert: Map of state to upsert into current state - no_longer_in_room: The server is not longer in the room, so the room - should e.g. be removed from `current_state_events` table. - """ - - to_delete = attr.ib(type=List[Tuple[str, str]]) - to_insert = attr.ib(type=StateMap[str]) - no_longer_in_room = attr.ib(type=bool, default=False) - - class _EventPeristenceQueue(object): """Queues up events so that they can be persisted in bulk with only one concurrent transaction per room. @@ -205,6 +189,7 @@ class EventsPersistenceStorage(object): # store for now. self.main_store = stores.main self.state_store = stores.state + self.persist_events_store = stores.persist_events self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -445,7 +430,7 @@ class EventsPersistenceStorage(object): if current_state is not None: current_state_for_room[room_id] = current_state - await self.main_store._persist_events_and_state_updates( + await self.persist_events_store._persist_events_and_state_updates( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, @@ -491,13 +476,15 @@ class EventsPersistenceStorage(object): ) # Remove any events which are prev_events of any existing events. - existing_prevs = await self.main_store._get_events_which_are_prevs(result) + existing_prevs = await self.persist_events_store._get_events_which_are_prevs( + result + ) result.difference_update(existing_prevs) # Finally handle the case where the new events have soft-failed prev # events. If they do we need to remove them and their prev events, # otherwise we end up with dangling extremities. - existing_prevs = await self.main_store._get_prevs_before_rejected( + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( e_id for event in new_events for e_id in event.prev_event_ids() ) result.difference_update(existing_prevs) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index d4bcf1821e..b45bc9c115 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() + self.persist_events_store = hs.get_datastores().persist_events @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): @@ -76,7 +77,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) yield self.store.db.runInteraction( "", - self.store._set_push_actions_for_event_and_users_txn, + self.persist_events_store._set_push_actions_for_event_and_users_txn, [(event, None)], [(event, None)], ) -- cgit 1.5.1 From 00ba9c48bf1da6a3173fd77ed72d7d8295f4d051 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 10:37:11 +0100 Subject: Spelling --- synapse/storage/data_stores/__init__.py | 2 +- synapse/storage/data_stores/main/events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 7c14ba91c2..791961b296 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -67,7 +67,7 @@ class DataStores(object): self.main = main_store_class(database, db_conn, hs) # If we're on a process that can persist events (currently - # master), also instansiate a `PersistEventsStore` + # master), also instantiate a `PersistEventsStore` if hs.config.worker.worker_app is None: self.persist_events = PersistEventsStore( hs, database, self.main diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 092739962b..a97f8b3934 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -118,7 +118,7 @@ class DeltaState: class PersistEventsStore: """Contains all the functions for writing events to the database. - Should only be instansiated on one process (when using a worker mode setup). + Should only be instantiated on one process (when using a worker mode setup). Note: This is not part of the `DataStore` mixin. """ -- cgit 1.5.1 From 1124111a12c3ab35f8b68d9031695aec8b2c7c50 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 13 May 2020 17:15:40 +0100 Subject: Allow censoring of events to happen on workers. (#7492) This is safe as we can now write to cache invalidation stream on workers, and is required for when we move event persistence off master. --- changelog.d/7492.misc | 1 + synapse/app/generic_worker.py | 2 ++ synapse/handlers/message.py | 2 -- synapse/storage/data_stores/main/censor_events.py | 7 +------ 4 files changed, 4 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7492.misc (limited to 'synapse/storage') diff --git a/changelog.d/7492.misc b/changelog.d/7492.misc new file mode 100644 index 0000000000..5ad31819d6 --- /dev/null +++ b/changelog.d/7492.misc @@ -0,0 +1 @@ +Allow censoring of events to happen on workers. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 667ad20428..bccb1140b2 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer +from synapse.storage.data_stores.main.censor_events import CensorEventsStore from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, @@ -442,6 +443,7 @@ class GenericWorkerSlavedStore( SlavedGroupServerStore, SlavedAccountDataStore, SlavedPusherStore, + CensorEventsStore, SlavedEventStore, SlavedKeyStore, RoomStore, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a622a600b4..0242521cc6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -72,7 +72,6 @@ class MessageHandler(object): self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages - self._is_worker_app = bool(hs.config.worker_app) # The scheduled call to self._expire_event. None if no call is currently # scheduled. @@ -260,7 +259,6 @@ class MessageHandler(object): Args: event (EventBase): The event to schedule the expiry of. """ - assert not self._is_worker_app expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) if not isinstance(expiry_ts, int) or event.is_state(): diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py index 49683b4d5a..2d48261724 100644 --- a/synapse/storage/data_stores/main/censor_events.py +++ b/synapse/storage/data_stores/main/censor_events.py @@ -33,15 +33,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class CensorEventsStore(CacheInvalidationWorkerStore, EventsWorkerStore, SQLBaseStore): +class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) - # This should only exist on master for now - assert ( - hs.config.worker.worker_app is None - ), "Can only instantiate CensorEventsStore on master" - def _censor_redactions(): return run_as_background_process( "_censor_redactions", self._censor_redactions -- cgit 1.5.1 From 86614e251f25ef3f1741b4cd9c9d60ee9e754862 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 15 May 2020 16:17:12 +0100 Subject: Fix a small typo in the arguments of simple_update in update_remote_profile_cache (#7511) --- changelog.d/7511.bugfix | 1 + synapse/storage/data_stores/main/profile.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7511.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7511.bugfix b/changelog.d/7511.bugfix new file mode 100644 index 0000000000..cf8bc69c6f --- /dev/null +++ b/changelog.d/7511.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug that broke the update remote profile background process. diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index 2b52cf9c1a..bfc9369f0b 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore): return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, - values={ + updatevalues={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), -- cgit 1.5.1 From 1f36ff69e8c7b2853e906598898878742d39c42b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 May 2020 16:43:59 +0100 Subject: Move event stream handling out of slave store. (#7491) This allows us to have the logic on both master and workers, which is necessary to move event persistence off master. We also combine the instantiation of ID generators from DataStore and slave stores to the base worker stores. This allows us to select which process writes events independently of the master/worker splits. --- changelog.d/7491.misc | 1 + synapse/replication/slave/storage/events.py | 89 ------------------- synapse/replication/slave/storage/push_rule.py | 8 -- synapse/storage/data_stores/main/__init__.py | 17 ---- synapse/storage/data_stores/main/cache.py | 102 +++++++++++++++++++++- synapse/storage/data_stores/main/events_worker.py | 35 ++++++++ synapse/storage/data_stores/main/push_rule.py | 14 +++ synapse/storage/util/id_generators.py | 11 +++ 8 files changed, 161 insertions(+), 116 deletions(-) create mode 100644 changelog.d/7491.misc (limited to 'synapse/storage') diff --git a/changelog.d/7491.misc b/changelog.d/7491.misc new file mode 100644 index 0000000000..50eb226db7 --- /dev/null +++ b/changelog.d/7491.misc @@ -0,0 +1 @@ +Move event stream handling out of slave store. diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index b313720a4b..1a1a50a24f 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -15,11 +15,6 @@ # limitations under the License. import logging -from synapse.api.constants import EventTypes -from synapse.replication.tcp.streams.events import ( - EventsStreamCurrentStateRow, - EventsStreamEventRow, -) from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore from synapse.storage.data_stores.main.event_push_actions import ( EventPushActionsWorkerStore, @@ -35,7 +30,6 @@ from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker logger = logging.getLogger(__name__) @@ -62,11 +56,6 @@ class SlavedEventStore( BaseSlavedStore, ): def __init__(self, database: Database, db_conn, hs): - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) - super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() @@ -92,81 +81,3 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "events": - self._stream_id_gen.advance(token) - for row in rows: - self._process_event_stream_row(token, row) - elif stream_name == "backfill": - self._backfill_id_gen.advance(-token) - for row in rows: - self.invalidate_caches_for_event( - -token, - row.event_id, - row.room_id, - row.type, - row.state_key, - row.redacts, - row.relates_to, - backfilled=True, - ) - return super().process_replication_rows(stream_name, instance_name, token, rows) - - def _process_event_stream_row(self, token, row): - data = row.data - - if row.type == EventsStreamEventRow.TypeId: - self.invalidate_caches_for_event( - token, - data.event_id, - data.room_id, - data.type, - data.state_key, - data.redacts, - data.relates_to, - backfilled=False, - ) - elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) - - if data.type == EventTypes.Member: - self.get_rooms_for_user_with_stream_ordering.invalidate( - (data.state_key,) - ) - else: - raise Exception("Unknown events stream row type %s" % (row.type,)) - - def invalidate_caches_for_event( - self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): - self._invalidate_get_event_cache(event_id) - - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) - - if not backfilled: - self._events_stream_cache.entity_has_changed(room_id, stream_ordering) - - if redacts: - self._invalidate_get_event_cache(redacts) - - if etype == EventTypes.Member: - self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) - - if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 5d5816d7eb..6adb19463a 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,19 +15,11 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore -from synapse.storage.database import Database -from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, database: Database, db_conn, hs): - self._push_rules_stream_id_gen = SlavedIdTracker( - db_conn, "push_rules_stream", "stream_id" - ) - super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) - def get_push_rules_stream_token(self): return ( self._push_rules_stream_id_gen.get_current_token(), diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 5df9dce79d..4b4763c701 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -24,7 +24,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( - ChainedIdGenerator, IdGenerator, MultiWriterIdGenerator, StreamIdGenerator, @@ -125,19 +124,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._stream_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - extra_tables=[("local_invites", "stream_id")], - ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -164,9 +150,6 @@ class DataStore( self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - self._push_rules_stream_id_gen = ChainedIdGenerator( - self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" - ) self._pushers_id_gen = StreamIdGenerator( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index 342a87a46b..eac5a4e55b 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,8 +16,13 @@ import itertools import logging -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, Tuple +from synapse.api.constants import EventTypes +from synapse.replication.tcp.streams.events import ( + EventsStreamCurrentStateRow, + EventsStreamEventRow, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -66,7 +71,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "caches": + if stream_name == "events": + for row in rows: + self._process_event_stream_row(token, row) + elif stream_name == "backfill": + for row in rows: + self._invalidate_caches_for_event( + -token, + row.event_id, + row.room_id, + row.type, + row.state_key, + row.redacts, + row.relates_to, + backfilled=True, + ) + elif stream_name == "caches": if self._cache_id_gen: self._cache_id_gen.advance(instance_name, token) @@ -85,6 +105,84 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def _process_event_stream_row(self, token, row): + data = row.data + + if row.type == EventsStreamEventRow.TypeId: + self._invalidate_caches_for_event( + token, + data.event_id, + data.room_id, + data.type, + data.state_key, + data.redacts, + data.relates_to, + backfilled=False, + ) + elif row.type == EventsStreamCurrentStateRow.TypeId: + self._curr_state_delta_stream_cache.entity_has_changed( + row.data.room_id, token + ) + + if data.type == EventTypes.Member: + self.get_rooms_for_user_with_stream_ordering.invalidate( + (data.state_key,) + ) + else: + raise Exception("Unknown events stream row type %s" % (row.type,)) + + def _invalidate_caches_for_event( + self, + stream_ordering, + event_id, + room_id, + etype, + state_key, + redacts, + relates_to, + backfilled, + ): + self._invalidate_get_event_cache(event_id) + + self.get_latest_event_ids_in_room.invalidate((room_id,)) + + self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + + if not backfilled: + self._events_stream_cache.entity_has_changed(room_id, stream_ordering) + + if redacts: + self._invalidate_get_event_cache(redacts) + + if etype == EventTypes.Member: + self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) + self.get_invited_rooms_for_local_user.invalidate((state_key,)) + + if relates_to: + self.get_relations_for_event.invalidate_many((relates_to,)) + self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_applicable_edit.invalidate((relates_to,)) + + async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + cache_func = getattr(self, cache_name, None) + if not cache_func: + return + + cache_func.invalidate(keys) + await self.db.runInteraction( + "invalidate_cache_and_stream", + self._send_invalidation_to_replication, + cache_func.__name__, + keys, + ) + def _invalidate_cache_and_stream(self, txn, cache_func, keys): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 970c31bd05..9130b74eb5 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -37,8 +37,10 @@ from synapse.events import make_event_from_dict from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks from synapse.util.iterutils import batch_iter @@ -74,6 +76,31 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) + if hs.config.worker_app is None: + # We are the process in charge of generating stream ids for events, + # so instantiate ID generators based on the database + self._stream_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + extra_tables=[("local_invites", "stream_id")], + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + # Another process is in charge of persisting events and generating + # stream IDs: rely on the replication streams to let us know which + # IDs we can process. + self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) + self._get_event_cache = Cache( "*getEvent*", keylen=3, @@ -85,6 +112,14 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_list = [] self._event_fetch_ongoing = 0 + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == "events": + self._stream_id_gen.advance(token) + elif stream_name == "backfill": + self._backfill_id_gen.advance(-token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b3faafa0a4..ef8f40959f 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -16,19 +16,23 @@ import abc import logging +from typing import Union from canonicaljson import json from twisted.internet import defer from synapse.push.baserules import list_with_base_rules +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore +from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.storage.util.id_generators import ChainedIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -64,6 +68,7 @@ class PushRulesWorkerStore( ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, + EventsWorkerStore, SQLBaseStore, ): """This is an abstract base class where subclasses must implement @@ -77,6 +82,15 @@ class PushRulesWorkerStore( def __init__(self, database: Database, db_conn, hs): super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) + if hs.config.worker.worker_app is None: + self._push_rules_stream_id_gen = ChainedIdGenerator( + self._stream_id_gen, db_conn, "push_rules_stream", "stream_id" + ) # type: Union[ChainedIdGenerator, SlavedIdTracker] + else: + self._push_rules_stream_id_gen = SlavedIdTracker( + db_conn, "push_rules_stream", "stream_id" + ) + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 86d04ea9ac..f89ce0bed2 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -166,6 +166,7 @@ class ChainedIdGenerator(object): def __init__(self, chained_generator, db_conn, table, column): self.chained_generator = chained_generator + self._table = table self._lock = threading.Lock() self._current_max = _load_current_id(db_conn, table, column) self._unfinished_ids = deque() # type: Deque[Tuple[int, int]] @@ -204,6 +205,16 @@ class ChainedIdGenerator(object): return self._current_max, self.chained_generator.get_current_token() + def advance(self, token: int): + """Stub implementation for advancing the token when receiving updates + over replication; raises an exception as this instance should be the + only source of updates. + """ + + raise Exception( + "Attempted to advance token on source for table %r", self._table + ) + class MultiWriterIdGenerator: """An ID generator that tracks a stream that can have multiple writers. -- cgit 1.5.1 From 16090a077f9f387dcd42edabda58063e3df6b771 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 15 May 2020 17:17:42 +0100 Subject: Prevent 0-member/null room_version rooms from appearing in group room queries (#7465) --- changelog.d/7465.bugfix | 1 + synapse/storage/data_stores/main/group_server.py | 92 ++++++++++++++++++++---- 2 files changed, 79 insertions(+), 14 deletions(-) create mode 100644 changelog.d/7465.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7465.bugfix b/changelog.d/7465.bugfix new file mode 100644 index 0000000000..1cbe50caa5 --- /dev/null +++ b/changelog.d/7465.bugfix @@ -0,0 +1 @@ +Prevent rooms with 0 members or with invalid version strings from breaking group queries. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 0963e6c250..fb1361f1c1 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id, include_private=False): + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ # TODO: Pagination - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + args = [group_id, False] - return self.db.simple_select_list( - table="group_rooms", - keyvalues=keyvalues, - retcols=("room_id", "is_public"), - desc="get_rooms_in_group", - ) + if not include_private: + sql += " AND is_public = ?" + args += [True] + + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] - def get_rooms_for_summary_by_category(self, group_id, include_private=False): + return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): """Get the rooms and categories that should be included in a summary request - Returns ([rooms], [categories]) + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category """ def _get_rooms_for_summary_txn(txn): @@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore): SELECT room_id, is_public, category_id, room_order FROM group_summary_rooms WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) """ if not include_private: sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) + txn.execute(sql, (group_id, False, True)) else: - txn.execute(sql, (group_id,)) + txn.execute(sql, (group_id, False)) rooms = [ { -- cgit 1.5.1 From 03aff4c75ed3b0b106ed1395b3d03b1ab9b013a6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 May 2020 17:22:47 +0100 Subject: Add a worker store for search insertion. (#7516) This is required as both event persistence and the background update needs access to this function. It should be perfectly safe for two workers to write to that table at the same time. --- changelog.d/7516.misc | 1 + synapse/app/generic_worker.py | 2 + synapse/storage/data_stores/main/search.py | 96 +++++++++++++++--------------- 3 files changed, 52 insertions(+), 47 deletions(-) create mode 100644 changelog.d/7516.misc (limited to 'synapse/storage') diff --git a/changelog.d/7516.misc b/changelog.d/7516.misc new file mode 100644 index 0000000000..94b0fd49b2 --- /dev/null +++ b/changelog.d/7516.misc @@ -0,0 +1 @@ +Add a worker store for search insertion, required for moving event persistence off master. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2e3add7ac5..ab801108ca 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.search import SearchWorkerStore from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -451,6 +452,7 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + SearchWorkerStore, BaseSlavedStore, ): def __init__(self, database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ee75b92344..13f49d8060 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -37,7 +37,55 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(SQLBaseStore): +class SearchWorkerStore(SQLBaseStore): + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore): return num_rows - def store_search_entries_txn(self, txn, entries): - """Add entries to the search table - - Args: - txn (cursor): - entries (iterable[SearchEntry]): - entries to be added to the table - """ - if not self.hs.config.enable_search: - return - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - - args = ( - ( - entry.event_id, - entry.room_id, - entry.key, - entry.value, - entry.stream_ordering, - entry.origin_server_ts, - ) - for entry in entries - ) - - txn.executemany(sql, args) - - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) - for entry in entries - ) - - txn.executemany(sql, args) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): -- cgit 1.5.1 From 34a43f0084ae2e16eb6e8f6eb2ecf4bb37bd4cf0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 15 May 2020 18:53:31 +0100 Subject: Fix a couple of small typos --- synapse/appservice/__init__.py | 2 +- synapse/storage/data_stores/main/appservice.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index aea3985a5f..1b13e84425 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -270,7 +270,7 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exlusive_user_regexes(self): + def get_exclusive_user_regexes(self): """Get the list of regexes used to determine if a user is exclusively registered by the AS """ diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index efbc06c796..7a1fe8cdd2 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -30,12 +30,12 @@ logger = logging.getLogger(__name__) def _make_exclusive_regex(services_cache): - # We precompie a regex constructed from all the regexes that the AS's + # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ regex.pattern for service in services_cache - for regex in service.get_exlusive_user_regexes() + for regex in service.get_exclusive_user_regexes() ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) -- cgit 1.5.1 From 6c1f7c722f0baade9aecf41f600fcced670c4fcb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 May 2020 19:03:25 +0100 Subject: Fix limit logic for AccountDataStream (#7384) Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream. --- changelog.d/7384.bugfix | 1 + synapse/replication/tcp/streams/_base.py | 68 +++++++++--- synapse/storage/data_stores/main/account_data.py | 62 +++++++---- tests/replication/tcp/streams/test_account_data.py | 117 +++++++++++++++++++++ 4 files changed, 217 insertions(+), 31 deletions(-) create mode 100644 changelog.d/7384.bugfix create mode 100644 tests/replication/tcp/streams/test_account_data.py (limited to 'synapse/storage') diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7384.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b48a6a3e91..d42aaff055 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,14 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import heapq import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) # the number of rows to request from an update_function. @@ -37,7 +50,7 @@ Token = int # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's # just a row from a database query, though this is dependent on the stream in question. # -StreamRow = Tuple +StreamRow = TypeVar("StreamRow", bound=Tuple) # The type returned by the update_function of a stream, as well as get_updates(), # get_updates_since, etc. @@ -533,32 +546,63 @@ class AccountDataStream(Stream): """ AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str ) NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), - db_query_to_update_function(self._update_function), + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit ) - async def _update_function(self, from_token, to_token, limit): - global_results, room_results = await self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True + + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type) + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token + ) + + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results ) - return results + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(room_rows, global_rows)) + return updates, to_token, limited class GroupServerStream(Stream): diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 46b494b334..f9eef1b78e 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, and type string. + A list of tuples of stream_id int, user_id string, + and type string. """ - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) + if last_id == current_id: + return [] - def get_updated_account_data_txn(txn): + def get_updated_global_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + + return await self.db.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() - return self.db.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn + return await self.db.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py new file mode 100644 index 0000000000..6a5116dd2a --- /dev/null +++ b/tests/replication/tcp/streams/test_account_data.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + AccountDataStream, +) + +from tests.replication._base import BaseStreamTestCase + + +class AccountDataStreamTestCase(BaseStreamTestCase): + def test_update_function_room_account_data_limit(self): + """Test replication with many room account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success( + store.add_account_data_to_room("test_user", "test_room", update, {}) + ) + updates.append(update) + + # also one global update + self.get_success(store.add_account_data_for_user("test_user", "m.global", {})) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertEqual(row.room_id, "test_room") + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.global") + self.assertIsNone(row.room_id) + + self.assertEqual([], received_rows) + + def test_update_function_global_account_data_limit(self): + """Test replication with many global account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success(store.add_account_data_for_user("test_user", update, {})) + updates.append(update) + + # also one per-room update + self.get_success( + store.add_account_data_to_room("test_user", "test_room", "m.per_room", {}) + ) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertIsNone(row.room_id) + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.per_room") + self.assertEqual(row.room_id, "test_room") + + self.assertEqual([], received_rows) -- cgit 1.5.1 From 08fa96f03037178620f5f0dd609fac52fbf7f2d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:07:24 +0100 Subject: Remove `exception_to_unicode` this is a no-op on python 3. --- synapse/storage/database.py | 15 +++------------ synapse/util/stringutils.py | 36 ------------------------------------ 2 files changed, 3 insertions(+), 48 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c3d0863429..9947dbce77 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor from synapse.types import Collection -from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -424,20 +423,14 @@ class Database(object): # This can happen if the database disappears mid # transaction. logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, + "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) + logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: @@ -449,9 +442,7 @@ class Database(object): conn.rollback() except self.engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, e1, ) continue raise diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 6899bcb788..2cfa5cf721 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -85,42 +85,6 @@ def to_ascii(s): return s -def exception_to_unicode(e): - """Helper function to extract the text of an exception as a unicode string - - Args: - e (Exception): exception to be stringified - - Returns: - unicode - """ - # urgh, this is a mess. The basic problem here is that psycopg2 constructs its - # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will - # then produce the raw byte sequence. Under Python 2, this will then cause another - # error if it gets mixed with a `unicode` object, as per - # https://github.com/matrix-org/synapse/issues/4252 - - # First of all, if we're under python3, everything is fine because it will sort this - # nonsense out for us. - if not PY2: - return str(e) - - # otherwise let's have a stab at decoding the exception message. We'll circumvent - # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii') - # and instead look at what is in the args member. - - if len(e.args) == 0: - return "" - elif len(e.args) > 1: - return six.text_type(repr(e.args)) - - msg = e.args[0] - if isinstance(msg, bytes): - return msg.decode("utf-8", errors="replace") - else: - return msg - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.5.1 From 65902e08c3f4449de9baa4e6466f126585f688b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:12:03 +0100 Subject: remove to_ascii this is a no-op on python 3. --- synapse/storage/data_stores/main/roommember.py | 25 ++++++++++--------------- synapse/storage/data_stores/main/state.py | 5 +---- synapse/util/stringutils.py | 20 +------------------- 3 files changed, 12 insertions(+), 38 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 48810a3e91..1e9c850152 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.metrics import Measure -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] + return [r[0] for r in txn] @cached(max_entries=100000) def get_room_summary(self, room_id): @@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + summary = res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] + summary = res[membership] # we will always have a summary for this membership type at this # point given the summary currently contains the counts. members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) + members.append((user_id, event_id)) return res @@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry: if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), + users_in_room[ev_entry.event.state_key] = ProfileInfo( + display_name=ev_entry.event.content.get("displayname", None), + avatar_url=ev_entry.event.content.get("avatar_url", None), ) else: missing_member_event_ids.append(event_id) @@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), + users_in_room[event.state_key] = ProfileInfo( + display_name=event.content.get("displayname", None), + avatar_url=event.content.get("avatar_url", None), ) return users_in_room diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 21052fcc7a..347cc50778 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -29,7 +29,6 @@ from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): (room_id,), ) - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return self.db.runInteraction( "get_current_state_ids", _get_current_state_ids_txn diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2cfa5cf721..81a44184ca 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,8 +19,7 @@ import re import string from collections import Iterable -import six -from six import PY2, PY3 +from six import PY3 from six.moves import range from synapse.api.errors import Codes, SynapseError @@ -68,23 +67,6 @@ def is_ascii(s): return True -def to_ascii(s): - """Converts a string to ascii if it is ascii, otherwise leave it alone. - - If given None then will return None. - """ - if PY3: - return s - - if s is None: - return None - - try: - return s.encode("ascii") - except UnicodeEncodeError: - return s - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: -- cgit 1.5.1 From e6027562e2a1964bcaa0163f1615ab72bfc6630b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:26:54 +0100 Subject: remove `builtins.buffer` code from storage code this is no longer needed on python 3 --- scripts-dev/convert_server_keys.py | 9 ++------- synapse/storage/_base.py | 8 -------- synapse/storage/data_stores/main/keys.py | 10 ++-------- synapse/storage/data_stores/main/transactions.py | 9 +-------- 4 files changed, 5 insertions(+), 31 deletions(-) (limited to 'synapse/storage') diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 06b4c1e2ff..961dc59f11 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -3,8 +3,6 @@ import json import sys import time -import six - import psycopg2 import yaml from canonicaljson import encode_canonical_json @@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -if six.PY2: - db_type = six.moves.builtins.buffer -else: - db_type = memoryview +db_binary_type = memoryview def select_v1_keys(connection): @@ -72,7 +67,7 @@ def rows_v2(server, json): valid_until = json["valid_until_ts"] key_json = encode_canonical_json(json) for key_id in json["verify_keys"]: - yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) + yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json)) def main(): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 59073c0a42..bfce541ca7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,9 +19,6 @@ import random from abc import ABCMeta from typing import Any, Optional -from six import PY2 -from six.moves import builtins - from canonicaljson import json from synapse.storage.database import LoggingTransaction # noqa: F401 @@ -103,11 +100,6 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # psycopg2 on Python 2 returns buffer objects, which we need to cast to - # bytes to decode - if PY2 and isinstance(db_content, builtins.buffer): - db_content = bytes(db_content) - # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. if isinstance(db_content, (bytes, bytearray)): diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ba89c68c9f..4e1642a27a 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -17,8 +17,6 @@ import itertools import logging -import six - from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore @@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview + +db_binary_type = memoryview class KeyStore(SQLBaseStore): diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 5b07c2fbc0..a9bf457939 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -16,8 +16,6 @@ import logging from collections import namedtuple -import six - from canonicaljson import encode_canonical_json from twisted.internet import defer @@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview +db_binary_type = memoryview logger = logging.getLogger(__name__) -- cgit 1.5.1 From f6f92845f84301619f4c48eed9364b25db0b1445 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 May 2020 13:20:10 +0100 Subject: Fix bug in persist events when dealing with non member types. (#7548) `_is_server_still_joined` will throw if it is given state updates with non-user ID state keys with local user leaves. This is actually rarely a problem since local leaves almost always get persisted by themselves. (I discovered this on a branch that was otherwise broken, so I haven't seen this in the wild) --- changelog.d/7548.bugfix | 1 + synapse/storage/persist_events.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7548.bugfix (limited to 'synapse/storage') diff --git a/changelog.d/7548.bugfix b/changelog.d/7548.bugfix new file mode 100644 index 0000000000..1233b3b31a --- /dev/null +++ b/changelog.d/7548.bugfix @@ -0,0 +1 @@ +Fix bug where a local user leaving a room could fail under rare circumstances. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 41881ea20b..12e1ffb9a2 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -740,8 +740,8 @@ class EventsPersistenceStorage(object): # whose state has changed as we've already their new state above. users_to_ignore = [ state_key - for _, state_key in itertools.chain(delta.to_insert, delta.to_delete) - if self.is_mine_id(state_key) + for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete) + if typ == EventTypes.Member and self.is_mine_id(state_key) ] if await self.main_store.is_local_host_in_room_ignoring_users( -- cgit 1.5.1