summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/databases/__init__.py10
-rw-r--r--synapse/storage/databases/main/__init__.py6
-rw-r--r--synapse/storage/databases/main/account_data.py88
-rw-r--r--synapse/storage/databases/main/cache.py72
-rw-r--r--synapse/storage/databases/main/client_ips.py33
-rw-r--r--synapse/storage/databases/main/delayed_events.py549
-rw-r--r--synapse/storage/databases/main/deviceinbox.py8
-rw-r--r--synapse/storage/databases/main/devices.py27
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py42
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py165
-rw-r--r--synapse/storage/databases/main/event_federation.py16
-rw-r--r--synapse/storage/databases/main/event_push_actions.py95
-rw-r--r--synapse/storage/databases/main/events.py1055
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py1425
-rw-r--r--synapse/storage/databases/main/events_worker.py248
-rw-r--r--synapse/storage/databases/main/media_repository.py110
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py93
-rw-r--r--synapse/storage/databases/main/profile.py346
-rw-r--r--synapse/storage/databases/main/purge_events.py50
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/receipts.py309
-rw-r--r--synapse/storage/databases/main/registration.py742
-rw-r--r--synapse/storage/databases/main/room.py345
-rw-r--r--synapse/storage/databases/main/roommember.py534
-rw-r--r--synapse/storage/databases/main/search.py8
-rw-r--r--synapse/storage/databases/main/sliding_sync.py603
-rw-r--r--synapse/storage/databases/main/state.py42
-rw-r--r--synapse/storage/databases/main/state_deltas.py176
-rw-r--r--synapse/storage/databases/main/stats.py8
-rw-r--r--synapse/storage/databases/main/stream.py603
-rw-r--r--synapse/storage/databases/main/tags.py70
-rw-r--r--synapse/storage/databases/main/transactions.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py40
-rw-r--r--synapse/storage/databases/state/bg_updates.py17
-rw-r--r--synapse/storage/databases/state/deletion.py561
-rw-r--r--synapse/storage/databases/state/store.py172
36 files changed, 7297 insertions, 1376 deletions
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py

index dd9fc01fb0..81886ff765 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py
@@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.databases.state.deletion import StateDeletionDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]): main state persist_events + state_deletion """ databases: List[DatabasePool] main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] + state_deletion: StateDeletionDataStore def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main @@ -63,6 +66,7 @@ class Databases(Generic[DataStoreT]): self.databases = [] main: Optional[DataStoreT] = None state: Optional[StateGroupDataStore] = None + state_deletion: Optional[StateDeletionDataStore] = None persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: @@ -114,7 +118,8 @@ class Databases(Generic[DataStoreT]): if state: raise Exception("'state' data store already configured") - state = StateGroupDataStore(database, db_conn, hs) + state_deletion = StateDeletionDataStore(database, db_conn, hs) + state = StateGroupDataStore(database, db_conn, hs, state_deletion) db_conn.commit() @@ -135,7 +140,7 @@ class Databases(Generic[DataStoreT]): if not main: raise Exception("No 'main' database configured") - if not state: + if not state or not state_deletion: raise Exception("No 'state' database configured") # We use local variables here to ensure that the databases do not have @@ -143,3 +148,4 @@ class Databases(Generic[DataStoreT]): self.main = main # type: ignore[assignment] self.state = state self.persist_events = persist_events + self.state_deletion = state_deletion diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 586e84f2a4..86431f6e40 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -3,7 +3,7 @@ # # Copyright 2019-2021 The Matrix.org Foundation C.I.C. # Copyright 2014-2016 OpenMarket Ltd -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -33,6 +33,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor @@ -43,6 +44,7 @@ from .appservice import ApplicationServiceStore, ApplicationServiceTransactionSt from .cache import CacheInvalidationWorkerStore from .censor_events import CensorEventsStore from .client_ips import ClientIpWorkerStore +from .delayed_events import DelayedEventsStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore from .directory import DirectoryStore @@ -156,6 +158,8 @@ class DataStore( LockStore, SessionStore, TaskSchedulerWorkerStore, + SlidingSyncStore, + DelayedEventsStore, ): def __init__( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 966393869b..715815cc09 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -34,6 +34,7 @@ from typing import ( ) from synapse.api.constants import AccountDataTypes +from synapse.api.errors import Codes, SynapseError from synapse.replication.tcp.streams import AccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import ( @@ -43,6 +44,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore +from synapse.storage.invite_rule import InviteRulesConfig from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder @@ -102,6 +104,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._delete_account_data_for_deactivated_users, ) + self._msc4155_enabled = hs.config.experimental.msc4155_enabled + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -177,7 +181,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_room_account_data_for_user_txn( txn: LoggingTransaction, - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: # The 'content != '{}' condition below prevents us from using # `simple_select_list_txn` here, as it doesn't support conditions # other than 'equals'. @@ -194,7 +198,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn.execute(sql, (user_id,)) - by_room: Dict[str, Dict[str, JsonDict]] = {} + by_room: Dict[str, Dict[str, JsonMapping]] = {} for room_id, account_data_type, content in txn: room_data = by_room.setdefault(room_id, {}) @@ -394,7 +398,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_updated_global_account_data_for_user( self, user_id: str, stream_id: int - ) -> Mapping[str, JsonMapping]: + ) -> Dict[str, JsonMapping]: """Get all the global account_data that's changed for a user. Args: @@ -407,7 +411,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_updated_global_account_data_for_user( txn: LoggingTransaction, - ) -> Dict[str, JsonDict]: + ) -> Dict[str, JsonMapping]: sql = """ SELECT account_data_type, content FROM account_data WHERE user_id = ? AND stream_id > ? @@ -429,7 +433,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def get_updated_room_account_data_for_user( self, user_id: str, stream_id: int - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: """Get all the room account_data that's changed for a user. Args: @@ -442,14 +446,14 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) def get_updated_room_account_data_for_user_txn( txn: LoggingTransaction, - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: sql = """ SELECT room_id, account_data_type, content FROM room_account_data WHERE user_id = ? AND stream_id > ? """ txn.execute(sql, (user_id, stream_id)) - account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} + account_data_by_room: Dict[str, Dict[str, JsonMapping]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -467,6 +471,56 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) get_updated_room_account_data_for_user_txn, ) + async def get_updated_room_account_data_for_user_for_room( + self, + # Since there are multiple arguments with the same type, force keyword arguments + # so people don't accidentally swap the order + *, + user_id: str, + room_id: str, + from_stream_id: int, + to_stream_id: int, + ) -> Dict[str, JsonMapping]: + """Get the room account_data that's changed for a user in a room. + + (> `from_stream_id` and <= `to_stream_id`) + + Args: + user_id: The user to get the account_data for. + room_id: The room to check + from_stream_id: The point in the stream to fetch from + to_stream_id: The point in the stream to fetch to + + Returns: + A dict of the room account data. + """ + + def get_updated_room_account_data_for_user_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonMapping]: + sql = """ + SELECT account_data_type, content FROM room_account_data + WHERE user_id = ? AND room_id = ? AND stream_id > ? AND stream_id <= ? + """ + txn.execute(sql, (user_id, room_id, from_stream_id, to_stream_id)) + + room_account_data: Dict[str, JsonMapping] = {} + for row in txn: + room_account_data[row[0]] = db_to_json(row[1]) + + return room_account_data + + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(from_stream_id) + ) + if not changed: + return {} + + return await self.db_pool.runInteraction( + "get_updated_room_account_data_for_user_for_room", + get_updated_room_account_data_for_user_for_room_txn, + ) + @cached(max_entries=5000, iterable=True) async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ @@ -507,6 +561,23 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) ) + async def get_invite_config_for_user(self, user_id: str) -> InviteRulesConfig: + """ + Get the invite configuration for the current user. + + Args: + user_id: + """ + + if not self._msc4155_enabled: + # This equates to allowing all invites, as if the setting was off. + return InviteRulesConfig(None) + + data = await self.get_global_account_data_by_type_for_user( + user_id, AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG + ) + return InviteRulesConfig(data) + def process_replication_rows( self, stream_name: str, @@ -710,6 +781,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: currently_ignored_users = set() + if user_id in currently_ignored_users: + raise SynapseError(400, "You cannot ignore yourself", Codes.INVALID_PARAM) + # If the data has not changed, nothing to do. if previously_ignored_users == currently_ignored_users: return diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 63624f3e8f..9418fb6dd7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -41,6 +41,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.util.caches.descriptors import CachedFunction @@ -218,6 +219,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) + self._curr_state_delta_stream_cache.entity_has_changed( # type: ignore[attr-defined] + room_id, token + ) + for user_id in members_changed: + self._membership_stream_cache.entity_has_changed(user_id, token) # type: ignore[attr-defined] elif row.cache_func == PURGE_HISTORY_CACHE_NAME: if row.keys is None: raise Exception( @@ -235,6 +241,35 @@ class CacheInvalidationWorkerStore(SQLBaseStore): room_id = row.keys[0] self._invalidate_caches_for_room_events(room_id) self._invalidate_caches_for_room(room_id) + self._curr_state_delta_stream_cache.entity_has_changed( # type: ignore[attr-defined] + room_id, token + ) + # Note: This code is commented out to improve cache performance. + # While uncommenting would provide complete correctness, our + # automatic forgotten room purge logic (see + # `forgotten_room_retention_period`) means this would frequently + # clear the entire cache (effectively) and probably have a noticable + # impact on the cache hit ratio. + # + # Not updating the cache here is safe because: + # + # 1. `_membership_stream_cache` is only used to indicate the + # *absence* of changes, i.e. "nothing has changed between tokens + # X and Y and so return early and don't query the database". + # 2. `_membership_stream_cache` is used when we query data from + # `current_state_delta_stream` and `room_memberships` but since + # nothing new is written to the database for those tables when + # purging/deleting a room (only deleting rows), there is nothing + # changed to care about. + # + # At worst, the cache might indicate a change at token X, at which + # point, we will query the database and discover nothing is there. + # + # Ideally, we would make it so that we could clear the cache on a + # more granular level but that's a bit complex and fiddly to do with + # room membership. + # + # self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined] else: self._attempt_to_invalidate_cache(row.cache_func, row.keys) @@ -271,20 +306,33 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_rooms_for_user", (data.state_key,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) + self._membership_stream_cache.entity_has_changed(data.state_key, token) # type: ignore[attr-defined] elif data.type == EventTypes.RoomEncryption: self._attempt_to_invalidate_cache( "get_room_encryption", (data.room_id,) ) elif data.type == EventTypes.Create: self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) + + if (data.type, data.state_key) in SLIDING_SYNC_RELEVANT_STATE_SET: + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) elif row.type == EventsStreamAllStateRow.TypeId: assert isinstance(data, EventsStreamAllStateRow) # Similar to the above, but the entire caches are invalidated. This is # unfortunate for the membership caches, but should recover quickly. self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined] + self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined] self._attempt_to_invalidate_cache("get_rooms_for_user", None) self._attempt_to_invalidate_cache("get_room_type", (data.room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,)) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) else: raise Exception("Unknown events stream row type %s" % (row.type,)) @@ -312,6 +360,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_unread_event_push_actions_by_room_for_user", (room_id,) ) + self._attempt_to_invalidate_cache("get_metadata_for_event", (room_id, event_id)) + + self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,)) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. @@ -344,6 +395,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "_get_rooms_for_local_user_where_membership_is_inner", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", + (state_key,), + ) self._attempt_to_invalidate_cache( "did_forget", @@ -360,6 +415,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): elif etype == EventTypes.RoomEncryption: self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) + if (etype, state_key) in SLIDING_SYNC_RELEVANT_STATE_SET: + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) + if relates_to: self._attempt_to_invalidate_cache( "get_relations_for_event", @@ -404,6 +464,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,)) + self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,)) + self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) self._attempt_to_invalidate_cache("get_applicable_edit", None) self._attempt_to_invalidate_cache("get_thread_id", None) @@ -413,6 +475,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "_get_rooms_for_local_user_where_membership_is_inner", None ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) self._attempt_to_invalidate_cache("did_forget", None) self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("get_references_for_event", None) @@ -425,6 +490,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("_get_state_group_for_event", None) self._attempt_to_invalidate_cache("get_event_ordering", None) + self._attempt_to_invalidate_cache("get_metadata_for_event", (room_id,)) self._attempt_to_invalidate_cache("is_partial_state_event", None) self._attempt_to_invalidate_cache("_get_joined_profile_from_event_id", None) @@ -450,6 +516,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_account_data_for_room", None) self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None) + self._attempt_to_invalidate_cache("get_tags_for_room", None) self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,)) self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None) @@ -469,6 +536,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_current_hosts_in_room_ordered", (room_id,) ) + self._attempt_to_invalidate_cache( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", None + ) self._attempt_to_invalidate_cache("did_forget", None) self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None) self._attempt_to_invalidate_cache("_get_membership_from_event_id", None) @@ -476,6 +546,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_room_type", (room_id,)) self._attempt_to_invalidate_cache("get_room_encryption", (room_id,)) + self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,)) + # And delete state caches. self._invalidate_state_caches_all(room_id) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 4b66247640..69008804bd 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -20,10 +20,19 @@ # import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Dict, + List, + Mapping, + Optional, + Tuple, + TypedDict, + Union, + cast, +) import attr -from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -238,9 +247,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): INNER JOIN user_ips USING (user_id, access_token, ip) GROUP BY user_id, access_token, ip HAVING count(*) > 1 - """.format( - clause - ), + """.format(clause), args, ) res = cast( @@ -373,9 +380,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): LIMIT ? ) c INNER JOIN user_ips AS u USING (user_id, device_id, last_seen) - """ % { - "where_clause": where_clause - } + """ % {"where_clause": where_clause} txn.execute(sql, where_args + [batch_size]) rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) @@ -645,9 +650,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke @wrap_as_background_process("update_client_ips") async def _update_client_ips_batch(self) -> None: - assert ( - self._update_on_this_worker - ), "This worker is not designated to update client IPs" + assert self._update_on_this_worker, ( + "This worker is not designated to update client IPs" + ) # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -666,9 +671,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke txn: LoggingTransaction, to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], ) -> None: - assert ( - self._update_on_this_worker - ), "This worker is not designated to update client IPs" + assert self._update_on_this_worker, ( + "This worker is not designated to update client IPs" + ) # Keys and values for the `user_ips` upsert. user_ips_keys = [] diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py new file mode 100644
index 0000000000..c88682d55c --- /dev/null +++ b/synapse/storage/databases/main/delayed_events.py
@@ -0,0 +1,549 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + +import logging +from typing import List, NewType, Optional, Tuple + +import attr + +from synapse.api.errors import NotFoundError +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction, StoreError +from synapse.storage.engines import PostgresEngine +from synapse.types import JsonDict, RoomID +from synapse.util import json_encoder, stringutils as stringutils + +logger = logging.getLogger(__name__) + + +DelayID = NewType("DelayID", str) +UserLocalpart = NewType("UserLocalpart", str) +DeviceID = NewType("DeviceID", str) +EventType = NewType("EventType", str) +StateKey = NewType("StateKey", str) + +Delay = NewType("Delay", int) +Timestamp = NewType("Timestamp", int) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventDetails: + room_id: RoomID + type: EventType + state_key: Optional[StateKey] + origin_server_ts: Optional[Timestamp] + content: JsonDict + device_id: Optional[DeviceID] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DelayedEventDetails(EventDetails): + delay_id: DelayID + user_localpart: UserLocalpart + + +class DelayedEventsStore(SQLBaseStore): + async def get_delayed_events_stream_pos(self) -> int: + """ + Gets the stream position of the background process to watch for state events + that target the same piece of state as any pending delayed events. + """ + return await self.db_pool.simple_select_one_onecol( + table="delayed_events_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_delayed_events_stream_pos", + ) + + async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None: + """ + Updates the stream position of the background process to watch for state events + that target the same piece of state as any pending delayed events. + + Must only be used by the worker running the background process. + """ + await self.db_pool.simple_update_one( + table="delayed_events_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_delayed_events_stream_pos", + ) + + async def add_delayed_event( + self, + *, + user_localpart: str, + device_id: Optional[str], + creation_ts: Timestamp, + room_id: str, + event_type: str, + state_key: Optional[str], + origin_server_ts: Optional[int], + content: JsonDict, + delay: int, + ) -> Tuple[DelayID, Timestamp]: + """ + Inserts a new delayed event in the DB. + + Returns: The generated ID assigned to the added delayed event, + and the send time of the next delayed event to be sent, + which is either the event just added or one added earlier. + """ + delay_id = _generate_delay_id() + send_ts = Timestamp(creation_ts + delay) + + def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp: + self.db_pool.simple_insert_txn( + txn, + table="delayed_events", + values={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "device_id": device_id, + "delay": delay, + "send_ts": send_ts, + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "origin_server_ts": origin_server_ts, + "content": json_encoder.encode(content), + }, + ) + + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert next_send_ts is not None + return next_send_ts + + next_send_ts = await self.db_pool.runInteraction( + "add_delayed_event", add_delayed_event_txn + ) + + return delay_id, next_send_ts + + async def restart_delayed_event( + self, + *, + delay_id: str, + user_localpart: str, + current_ts: Timestamp, + ) -> Timestamp: + """ + Restarts the send time of the matching delayed event, + as long as it hasn't already been marked for processing. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + current_ts: The current time, which will be used to calculate the new send time. + + Returns: The send time of the next delayed event to be sent, + which is either the event just restarted, or another one + with an earlier send time than the restarted one's new send time. + + Raises: + NotFoundError: if there is no matching delayed event. + """ + + def restart_delayed_event_txn( + txn: LoggingTransaction, + ) -> Timestamp: + txn.execute( + """ + UPDATE delayed_events + SET send_ts = ? + delay + WHERE delay_id = ? AND user_localpart = ? + AND NOT is_processed + """, + ( + current_ts, + delay_id, + user_localpart, + ), + ) + if txn.rowcount == 0: + raise NotFoundError("Delayed event not found") + + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + assert next_send_ts is not None + return next_send_ts + + return await self.db_pool.runInteraction( + "restart_delayed_event", restart_delayed_event_txn + ) + + async def get_all_delayed_events_for_user( + self, + user_localpart: str, + ) -> List[JsonDict]: + """Returns all pending delayed events owned by the given user.""" + # TODO: Support Pagination stream API ("next_batch" field) + rows = await self.db_pool.execute( + "get_all_delayed_events_for_user", + """ + SELECT + delay_id, + room_id, + event_type, + state_key, + delay, + send_ts, + content + FROM delayed_events + WHERE user_localpart = ? AND NOT is_processed + ORDER BY send_ts + """, + user_localpart, + ) + return [ + { + "delay_id": DelayID(row[0]), + "room_id": str(RoomID.from_string(row[1])), + "type": EventType(row[2]), + **({"state_key": StateKey(row[3])} if row[3] is not None else {}), + "delay": Delay(row[4]), + "running_since": Timestamp(row[5] - row[4]), + "content": db_to_json(row[6]), + } + for row in rows + ] + + async def process_timeout_delayed_events( + self, current_ts: Timestamp + ) -> Tuple[ + List[DelayedEventDetails], + Optional[Timestamp], + ]: + """ + Marks for processing all delayed events that should have been sent prior to the provided time + that haven't already been marked as such. + + Returns: The details of all newly-processed delayed events, + and the send time of the next delayed event to be sent, if any. + """ + + def process_timeout_delayed_events_txn( + txn: LoggingTransaction, + ) -> Tuple[ + List[DelayedEventDetails], + Optional[Timestamp], + ]: + sql_cols = ", ".join( + ( + "delay_id", + "user_localpart", + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "send_ts", + "content", + "device_id", + ) + ) + sql_update = "UPDATE delayed_events SET is_processed = TRUE" + sql_where = "WHERE send_ts <= ? AND NOT is_processed" + sql_args = (current_ts,) + sql_order = "ORDER BY send_ts" + if isinstance(self.database_engine, PostgresEngine): + # Do this only in Postgres because: + # - SQLite's RETURNING emits rows in an arbitrary order + # - https://www.sqlite.org/lang_returning.html#limitations_and_caveats + # - SQLite does not support data-modifying statements in a WITH clause + # - https://www.sqlite.org/lang_with.html + # - https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-MODIFYING + txn.execute( + f""" + WITH events_to_send AS ( + {sql_update} {sql_where} RETURNING * + ) SELECT {sql_cols} FROM events_to_send {sql_order} + """, + sql_args, + ) + rows = txn.fetchall() + else: + txn.execute( + f"SELECT {sql_cols} FROM delayed_events {sql_where} {sql_order}", + sql_args, + ) + rows = txn.fetchall() + txn.execute(f"{sql_update} {sql_where}", sql_args) + assert txn.rowcount == len(rows) + + events = [ + DelayedEventDetails( + RoomID.from_string(row[2]), + EventType(row[3]), + StateKey(row[4]) if row[4] is not None else None, + # If no custom_origin_ts is set, use send_ts as the event's timestamp + Timestamp(row[5] if row[5] is not None else row[6]), + db_to_json(row[7]), + DeviceID(row[8]) if row[8] is not None else None, + DelayID(row[0]), + UserLocalpart(row[1]), + ) + for row in rows + ] + next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) + return events, next_send_ts + + return await self.db_pool.runInteraction( + "process_timeout_delayed_events", process_timeout_delayed_events_txn + ) + + async def process_target_delayed_event( + self, + *, + delay_id: str, + user_localpart: str, + ) -> Tuple[ + EventDetails, + Optional[Timestamp], + ]: + """ + Marks for processing the matching delayed event, regardless of its timeout time, + as long as it has not already been marked as such. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + + Returns: The details of the matching delayed event, + and the send time of the next delayed event to be sent, if any. + + Raises: + NotFoundError: if there is no matching delayed event. + """ + + def process_target_delayed_event_txn( + txn: LoggingTransaction, + ) -> Tuple[ + EventDetails, + Optional[Timestamp], + ]: + sql_cols = ", ".join( + ( + "room_id", + "event_type", + "state_key", + "origin_server_ts", + "content", + "device_id", + ) + ) + sql_update = "UPDATE delayed_events SET is_processed = TRUE" + sql_where = "WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed" + sql_args = (delay_id, user_localpart) + txn.execute( + ( + f"{sql_update} {sql_where} RETURNING {sql_cols}" + if self.database_engine.supports_returning + else f"SELECT {sql_cols} FROM delayed_events {sql_where}" + ), + sql_args, + ) + row = txn.fetchone() + if row is None: + raise NotFoundError("Delayed event not found") + elif not self.database_engine.supports_returning: + txn.execute(f"{sql_update} {sql_where}", sql_args) + assert txn.rowcount == 1 + + event = EventDetails( + RoomID.from_string(row[0]), + EventType(row[1]), + StateKey(row[2]) if row[2] is not None else None, + Timestamp(row[3]) if row[3] is not None else None, + db_to_json(row[4]), + DeviceID(row[5]) if row[5] is not None else None, + ) + + return event, self._get_next_delayed_event_send_ts_txn(txn) + + return await self.db_pool.runInteraction( + "process_target_delayed_event", process_target_delayed_event_txn + ) + + async def cancel_delayed_event( + self, + *, + delay_id: str, + user_localpart: str, + ) -> Optional[Timestamp]: + """ + Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. + + Args: + delay_id: The ID of the delayed event to restart. + user_localpart: The localpart of the delayed event's owner. + + Returns: The send time of the next delayed event to be sent, if any. + + Raises: + NotFoundError: if there is no matching delayed event. + """ + + def cancel_delayed_event_txn( + txn: LoggingTransaction, + ) -> Optional[Timestamp]: + try: + self.db_pool.simple_delete_one_txn( + txn, + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": False, + }, + ) + except StoreError: + if txn.rowcount == 0: + raise NotFoundError("Delayed event not found") + else: + raise + + return self._get_next_delayed_event_send_ts_txn(txn) + + return await self.db_pool.runInteraction( + "cancel_delayed_event", cancel_delayed_event_txn + ) + + async def cancel_delayed_state_events( + self, + *, + room_id: str, + event_type: str, + state_key: str, + not_from_localpart: str, + ) -> Optional[Timestamp]: + """ + Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. + + Args: + room_id: The room ID to match against. + event_type: The event type to match against. + state_key: The state key to match against. + not_from_localpart: The localpart of a user whose delayed events to not cancel. + If set to the empty string, any users' delayed events may be cancelled. + + Returns: The send time of the next delayed event to be sent, if any. + """ + + def cancel_delayed_state_events_txn( + txn: LoggingTransaction, + ) -> Optional[Timestamp]: + txn.execute( + """ + DELETE FROM delayed_events + WHERE room_id = ? AND event_type = ? AND state_key = ? + AND user_localpart <> ? + AND NOT is_processed + """, + ( + room_id, + event_type, + state_key, + not_from_localpart, + ), + ) + return self._get_next_delayed_event_send_ts_txn(txn) + + return await self.db_pool.runInteraction( + "cancel_delayed_state_events", cancel_delayed_state_events_txn + ) + + async def delete_processed_delayed_event( + self, + delay_id: DelayID, + user_localpart: UserLocalpart, + ) -> None: + """ + Delete the matching delayed event, as long as it has been marked as processed. + + Throws: + StoreError: if there is no matching delayed event, or if it has not yet been processed. + """ + return await self.db_pool.simple_delete_one( + table="delayed_events", + keyvalues={ + "delay_id": delay_id, + "user_localpart": user_localpart, + "is_processed": True, + }, + desc="delete_processed_delayed_event", + ) + + async def delete_processed_delayed_state_events( + self, + *, + room_id: str, + event_type: str, + state_key: str, + ) -> None: + """ + Delete the matching delayed state events that have been marked as processed. + """ + await self.db_pool.simple_delete( + table="delayed_events", + keyvalues={ + "room_id": room_id, + "event_type": event_type, + "state_key": state_key, + "is_processed": True, + }, + desc="delete_processed_delayed_state_events", + ) + + async def unprocess_delayed_events(self) -> None: + """ + Unmark all delayed events for processing. + """ + await self.db_pool.simple_update( + table="delayed_events", + keyvalues={"is_processed": True}, + updatevalues={"is_processed": False}, + desc="unprocess_delayed_events", + ) + + async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]: + """ + Returns the send time of the next delayed event to be sent, if any. + """ + return await self.db_pool.runInteraction( + "get_next_delayed_event_send_ts", + self._get_next_delayed_event_send_ts_txn, + db_autocommit=True, + ) + + def _get_next_delayed_event_send_ts_txn( + self, txn: LoggingTransaction + ) -> Optional[Timestamp]: + result = self.db_pool.simple_select_one_onecol_txn( + txn, + table="delayed_events", + keyvalues={"is_processed": False}, + retcol="MIN(send_ts)", + allow_none=True, + ) + return Timestamp(result) if result is not None else None + + +def _generate_delay_id() -> DelayID: + """Generates an opaque string, for use as a delay ID""" + + # We use the following format for delay IDs: + # syd_<random string> + # They are scoped to user localparts, so it is possible for + # the same ID to exist for multiple users. + + return DelayID(f"syd_{stringutils.random_string(20)}") diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 042d595ea0..d47833655d 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -200,9 +200,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): to_stream_id=to_stream_id, ) - assert ( - last_processed_stream_id == to_stream_id - ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" + assert last_processed_stream_id == to_stream_id, ( + "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" + ) return user_id_device_id_to_messages @@ -1116,7 +1116,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): txn.execute(sql, (start, stop)) - destinations = {d for d, in txn} + destinations = {d for (d,) in txn} to_remove = set() for d in destinations: try: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 53024bddc3..6191f22cd6 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -27,6 +27,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, Set, @@ -35,7 +36,6 @@ from typing import ( ) from canonicaljson import encode_canonical_json -from typing_extensions import Literal from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError @@ -282,9 +282,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "count_devices_by_users", count_devices_by_users_txn, user_ids ) + @cached() async def get_device( self, user_id: str, device_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[Mapping[str, Any]]: """Retrieve a device. Only returns devices that are not marked as hidden. @@ -670,9 +671,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): result["keys"] = keys device_display_name = None - if ( - self.hs.config.federation.allow_device_name_lookup_over_federation - ): + if self.hs.config.federation.allow_device_name_lookup_over_federation: device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name @@ -917,7 +916,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): from_key, to_key, ) - return {u for u, in rows} + return {u for (u,) in rows} @cancellable async def get_users_whose_devices_changed( @@ -968,7 +967,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn.database_engine, "user_id", chunk ) txn.execute(sql % (clause,), [from_key, to_key] + args) - changes.update(user_id for user_id, in txn) + changes.update(user_id for (user_id,) in txn) return changes @@ -1093,7 +1092,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ), ) - results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} + results: Dict[str, Optional[str]] = dict.fromkeys(user_ids) results.update(rows) return results @@ -1424,7 +1423,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): DELETE FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, [(row[0], row[1]) for row in rows]) logger.info("Pruned %d device list outbound pokes", count) @@ -1520,7 +1519,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): args: List[Any], ) -> Set[str]: txn.execute(sql.format(clause=clause), args) - return {user_id for user_id, in txn} + return {user_id for (user_id,) in txn} changes = set() for chunk in batch_iter(changed_room_ids, 1000): @@ -1560,7 +1559,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): txn: LoggingTransaction, ) -> Set[str]: txn.execute(sql, (from_id, to_id)) - return {room_id for room_id, in txn} + return {room_id for (room_id,) in txn} return await self.db_pool.runInteraction( "get_all_device_list_changes", @@ -1819,6 +1818,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): }, desc="store_device", ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) + if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else @@ -1884,6 +1885,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=device_ids, keyvalues={"user_id": user_id}, ) + self._invalidate_cache_and_stream_bulk( + txn, self.get_device, [(user_id, device_id) for device_id in device_ids] + ) for batch in batch_iter(device_ids, 100): await self.db_pool.runInteraction( @@ -1917,6 +1921,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updatevalues=updates, desc="update_device", ) + await self.invalidate_cache_and_stream("get_device", (user_id, device_id)) async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: str diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 4d6a921ab2..904ae5cb58 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -19,9 +19,18 @@ # # -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast - -from typing_extensions import Literal, TypedDict +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Tuple, + TypedDict, + cast, +) from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace @@ -387,9 +396,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): is_verified, session_data FROM e2e_room_keys WHERE user_id = ? AND version = ? AND (%s) - """ % ( - " OR ".join(where_clauses) - ) + """ % (" OR ".join(where_clauses)) txn.execute(sql, params) @@ -512,19 +519,16 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - row = cast( - Tuple[int, str, str, Optional[int]], - self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": this_version, - "deleted": 0, - }, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, - ), + row = self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, ) return { "auth_data": db_to_json(row[2]), diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9e6c9561ae..341e7014d6 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -27,6 +27,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, Optional, Sequence, @@ -39,7 +40,6 @@ from typing import ( import attr from canonicaljson import encode_canonical_json -from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( @@ -99,6 +99,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore): unique=True, ) + self.db_pool.updates.register_background_index_update( + update_name="add_otk_ts_added_index", + index_name="e2e_one_time_keys_json_user_id_device_id_algorithm_ts_added_idx", + table="e2e_one_time_keys_json", + columns=("user_id", "device_id", "algorithm", "ts_added_ms"), + ) + class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore): def __init__( @@ -472,9 +479,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker signature_sql = """ SELECT user_id, key_id, target_device_id, signature FROM e2e_cross_signing_signatures WHERE %s - """ % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) txn.execute(signature_sql, signature_query_params) return cast( @@ -917,9 +922,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker FROM e2e_cross_signing_keys WHERE %(clause)s ORDER BY user_id, keytype, stream_id DESC - """ % { - "clause": clause - } + """ % {"clause": clause} else: # SQLite has special handling for bare columns when using # MIN/MAX with a `GROUP BY` clause where it picks the value from @@ -929,9 +932,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker FROM e2e_cross_signing_keys WHERE %(clause)s GROUP BY user_id, keytype - """ % { - "clause": clause - } + """ % {"clause": clause} txn.execute(sql, params) @@ -1128,7 +1129,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """Take a list of one time keys out of the database. Args: - query_list: An iterable of tuples of (user ID, device ID, algorithm). + query_list: An iterable of tuples of (user ID, device ID, algorithm, number of keys). Returns: A tuple (results, missing) of: @@ -1316,9 +1317,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker OTK was found. """ + # Return the oldest keys from this device (based on `ts_added_ms`). + # Doing so means that keys are issued in the same order they were uploaded, + # which reduces the chances of a client expiring its copy of a (private) + # key while the public key is still on the server, waiting to be issued. sql = """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? + ORDER BY ts_added_ms LIMIT ? """ @@ -1360,13 +1366,22 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker A list of tuples (user_id, device_id, algorithm, key_id, key_json) for each OTK claimed. """ + # Find, delete, and return the oldest keys from each device (based on + # `ts_added_ms`). + # + # Doing so means that keys are issued in the same order they were uploaded, + # which reduces the chances of a client expiring its copy of a (private) + # key while the public key is still on the server, waiting to be issued. sql = """ WITH claims(user_id, device_id, algorithm, claim_count) AS ( VALUES ? ), ranked_keys AS ( SELECT user_id, device_id, algorithm, key_id, claim_count, - ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r + ROW_NUMBER() OVER ( + PARTITION BY (user_id, device_id, algorithm) + ORDER BY ts_added_ms + ) AS r FROM e2e_one_time_keys_json JOIN claims USING (user_id, device_id, algorithm) ) @@ -1438,6 +1453,93 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker impl, ) + async def delete_old_otks_for_next_user_batch( + self, after_user_id: str, number_of_users: int + ) -> Tuple[List[str], int]: + """Deletes old OTKs belonging to the next batch of users + + Returns: + `(users, rows)`, where: + * `users` is the user IDs of the updated users. An empty list if we are done. + * `rows` is the number of deleted rows + """ + + def impl(txn: LoggingTransaction) -> Tuple[List[str], int]: + # Find a batch of users + txn.execute( + """ + SELECT DISTINCT(user_id) FROM e2e_one_time_keys_json + WHERE user_id > ? + ORDER BY user_id + LIMIT ? + """, + (after_user_id, number_of_users), + ) + users = [row[0] for row in txn.fetchall()] + if len(users) == 0: + return users, 0 + + # Delete any old OTKs belonging to those users. + # + # We only actually consider OTKs whose key ID is 6 characters long. These + # keys were likely made by libolm rather than Vodozemac; libolm only kept + # 100 private OTKs, so was far more vulnerable than Vodozemac to throwing + # away keys prematurely. + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", users + ) + sql = f""" + DELETE FROM e2e_one_time_keys_json + WHERE {clause} AND ts_added_ms < ? AND length(key_id) = 6 + """ + args.append(self._clock.time_msec() - (7 * 24 * 3600 * 1000)) + txn.execute(sql, args) + + return users, txn.rowcount + + return await self.db_pool.runInteraction( + "delete_old_otks_for_next_user_batch", impl + ) + + async def allow_master_cross_signing_key_replacement_without_uia( + self, user_id: str, duration_ms: int + ) -> Optional[int]: + """Mark this user's latest master key as being replaceable without UIA. + + Said replacement will only be permitted for a short time after calling this + function. That time period is controlled by the duration argument. + + Returns: + None, if there is no such key. + Otherwise, the timestamp before which replacement is allowed without UIA. + """ + timestamp = self._clock.time_msec() + duration_ms + + def impl(txn: LoggingTransaction) -> Optional[int]: + txn.execute( + """ + UPDATE e2e_cross_signing_keys + SET updatable_without_uia_before_ms = ? + WHERE stream_id = ( + SELECT stream_id + FROM e2e_cross_signing_keys + WHERE user_id = ? AND keytype = 'master' + ORDER BY stream_id DESC + LIMIT 1 + ) + """, + (timestamp, user_id), + ) + if txn.rowcount == 0: + return None + + return timestamp + + return await self.db_pool.runInteraction( + "allow_master_cross_signing_key_replacement_without_uia", + impl, + ) + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def __init__( @@ -1692,42 +1794,3 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ], desc="add_e2e_signing_key", ) - - async def allow_master_cross_signing_key_replacement_without_uia( - self, user_id: str, duration_ms: int - ) -> Optional[int]: - """Mark this user's latest master key as being replaceable without UIA. - - Said replacement will only be permitted for a short time after calling this - function. That time period is controlled by the duration argument. - - Returns: - None, if there is no such key. - Otherwise, the timestamp before which replacement is allowed without UIA. - """ - timestamp = self._clock.time_msec() + duration_ms - - def impl(txn: LoggingTransaction) -> Optional[int]: - txn.execute( - """ - UPDATE e2e_cross_signing_keys - SET updatable_without_uia_before_ms = ? - WHERE stream_id = ( - SELECT stream_id - FROM e2e_cross_signing_keys - WHERE user_id = ? AND keytype = 'master' - ORDER BY stream_id DESC - LIMIT 1 - ) - """, - (timestamp, user_id), - ) - if txn.rowcount == 0: - return None - - return timestamp - - return await self.db_pool.runInteraction( - "allow_master_cross_signing_key_replacement_without_uia", - impl, - ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 715846865b..46aa5902d8 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -326,7 +326,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ rows = txn.execute_values(sql, chains.items()) - results.update(r for r, in rows) + results.update(r for (r,) in rows) else: # For SQLite we just fall back to doing a noddy for loop. sql = """ @@ -335,7 +335,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ for chain_id, max_no in chains.items(): txn.execute(sql, (chain_id, max_no)) - results.update(r for r, in txn) + results.update(r for (r,) in txn) return results @@ -645,7 +645,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ] rows = txn.execute_values(sql, args) - result.update(r for r, in rows) + result.update(r for (r,) in rows) else: # For SQLite we just fall back to doing a noddy for loop. sql = """ @@ -654,7 +654,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas """ for chain_id, (min_no, max_no) in chain_to_gap.items(): txn.execute(sql, (chain_id, min_no, max_no)) - result.update(r for r, in txn) + result.update(r for (r,) in txn) return result @@ -1220,13 +1220,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas HAVING count(*) > ? ORDER BY count(*) DESC LIMIT ? - """ % ( - where_clause, - ) + """ % (where_clause,) query_args = list(itertools.chain(room_id_filter, [min_count, limit])) txn.execute(sql, query_args) - return [room_id for room_id, in txn] + return [room_id for (room_id,) in txn] return await self.db_pool.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn @@ -1358,7 +1356,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (stream_ordering, room_id)) - return [event_id for event_id, in txn] + return [event_id for (event_id,) in txn] event_ids = await self.db_pool.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 0ebf5b53d5..6fb4a6df8c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -1034,97 +1034,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # one of the subqueries may have hit the limit. return notifs[:limit] - async def get_unread_push_actions_for_user_in_range_for_email( - self, - user_id: str, - min_stream_ordering: int, - max_stream_ordering: int, - limit: int = 20, - ) -> List[EmailPushAction]: - """Get a list of the most recent unread push actions for a given user, - within the given stream ordering range. Called by the emailpusher - - Args: - user_id: The user to fetch push actions for. - min_stream_ordering: The exclusive lower bound on the - stream ordering of event push actions to fetch. - max_stream_ordering: The inclusive upper bound on the - stream ordering of event push actions to fetch. - limit: The maximum number of rows to return. - Returns: - A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". - The list will be ordered by descending received_ts. - The list will have between 0~limit entries. - """ - - def get_push_actions_txn( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, int, str, bool, int]]: - sql = """ - SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering, - ep.actions, ep.highlight, e.received_ts - FROM event_push_actions AS ep - INNER JOIN events AS e USING (room_id, event_id) - WHERE - ep.user_id = ? - AND ep.stream_ordering > ? - AND ep.stream_ordering <= ? - AND ep.notif = 1 - ORDER BY ep.stream_ordering DESC LIMIT ? - """ - txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) - return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall()) - - push_actions = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn - ) - - room_ids = set() - thread_ids = [] - for ( - _, - room_id, - thread_id, - _, - _, - _, - _, - ) in push_actions: - room_ids.add(room_id) - thread_ids.append(thread_id) - - receipts_by_room = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_receipts", - self._get_receipts_for_room_and_threads_txn, - user_id=user_id, - room_ids=room_ids, - thread_ids=thread_ids, - ) - - # Make a list of dicts from the two sets of results. - notifs = [ - EmailPushAction( - event_id=event_id, - room_id=room_id, - stream_ordering=stream_ordering, - actions=_deserialize_action(actions, highlight), - received_ts=received_ts, - ) - for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions - if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread( - thread_id, stream_ordering - ) - ] - - # Now sort it so it's ordered correctly, since currently it will - # contain results from the first query, correctly ordered, followed - # by results from the second query, but we want them all ordered - # by received_ts (most recent first) - notifs.sort(key=lambda r: -(r.received_ts or 0)) - - # Now return the first `limit` - return notifs[:limit] - async def get_if_maybe_push_in_range_for_user( self, user_id: str, min_stream_ordering: int ) -> bool: @@ -1860,9 +1769,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas AND epa.notif = 1 ORDER BY epa.stream_ordering DESC LIMIT ? - """ % ( - before_clause, - ) + """ % (before_clause,) txn.execute(sql, args) return cast( List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall() diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f7acdb859..b7cc0433e7 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -32,8 +32,10 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, + TypedDict, cast, ) @@ -41,17 +43,24 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, +) from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, StrippedStateEvent, relation_from_event from synapse.events.snapshot import EventContext +from synapse.events.utils import parse_stripped_state_event from synapse.logging.opentracing import trace from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.event_federation import EventFederationStore from synapse.storage.databases.main.events_worker import EventCacheEntry @@ -59,7 +68,15 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id +from synapse.types import ( + JsonDict, + MutableStateMap, + StateMap, + StrCollection, + get_domain_from_id, +) +from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.types.state import StateFilter from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.stringutils import non_null_str_or_none @@ -78,6 +95,19 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +# State event type/key pairs that we need to gather to fill in the +# `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. +SLIDING_SYNC_RELEVANT_STATE_SET = ( + # So we can fill in the `room_type` column + (EventTypes.Create, ""), + # So we can fill in the `is_encrypted` column + (EventTypes.RoomEncryption, ""), + # So we can fill in the `room_name` column + (EventTypes.Name, ""), + # So we can fill in the `tombstone_successor_room_id` column + (EventTypes.Tombstone, ""), +) + @attr.s(slots=True, auto_attribs=True) class DeltaState: @@ -99,6 +129,80 @@ class DeltaState: return not self.to_delete and not self.to_insert and not self.no_longer_in_room +# We want `total=False` because we want to allow values to be unset. +class SlidingSyncStateInsertValues(TypedDict, total=False): + """ + Insert values relevant for the `sliding_sync_joined_rooms` and + `sliding_sync_membership_snapshots` database tables. + """ + + room_type: Optional[str] + is_encrypted: Optional[bool] + room_name: Optional[str] + tombstone_successor_room_id: Optional[str] + + +class SlidingSyncMembershipSnapshotSharedInsertValues( + SlidingSyncStateInsertValues, total=False +): + """ + Insert values for `sliding_sync_membership_snapshots` that we can share across + multiple memberships + """ + + has_known_state: Optional[bool] + + +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncMembershipInfo: + """ + Values unique to each membership + """ + + user_id: str + sender: str + membership_event_id: str + membership: str + + +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncMembershipInfoWithEventPos(SlidingSyncMembershipInfo): + """ + SlidingSyncMembershipInfo + `stream_ordering`/`instance_name` of the membership + event + """ + + membership_event_stream_ordering: int + membership_event_instance_name: str + + +@attr.s(slots=True, auto_attribs=True) +class SlidingSyncTableChanges: + room_id: str + # If the row doesn't exist in the `sliding_sync_joined_rooms` table, we need to + # fully-insert it which means we also need to include a `bump_stamp` value to use + # for the row. This should only be populated when we're trying to fully-insert a + # row. + # + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + joined_room_bump_stamp_to_fully_insert: Optional[int] + # Values to upsert into `sliding_sync_joined_rooms` + joined_room_updates: SlidingSyncStateInsertValues + + # Shared values to upsert into `sliding_sync_membership_snapshots` for each + # `to_insert_membership_snapshots` + membership_snapshot_shared_insert_values: ( + SlidingSyncMembershipSnapshotSharedInsertValues + ) + # List of membership to insert into `sliding_sync_membership_snapshots` + to_insert_membership_snapshots: List[SlidingSyncMembershipInfo] + # List of user_id to delete from `sliding_sync_membership_snapshots` + to_delete_membership_snapshots: List[str] + + @attr.s(slots=True, auto_attribs=True) class NewEventChainLinks: """Information about new auth chain links that need to be added to the DB. @@ -142,9 +246,9 @@ class PersistEventsStore: self.is_mine_id = hs.is_mine_id # This should only exist on instances that are configured to write - assert ( - hs.get_instance_name() in hs.config.worker.writers.events - ), "Can only instantiate EventsStore on master" + assert hs.get_instance_name() in hs.config.worker.writers.events, ( + "Can only instantiate EventsStore on master" + ) # Since we have been configured to write, we ought to have id generators, # rather than id trackers. @@ -223,9 +327,24 @@ class PersistEventsStore: async with stream_ordering_manager as stream_orderings: for (event, _), stream in zip(events_and_contexts, stream_orderings): + # XXX: We can't rely on `stream_ordering`/`instance_name` being correct + # at this point. We could be working with events that were previously + # persisted as an `outlier` with one `stream_ordering` but are now being + # persisted again and de-outliered and are being assigned a different + # `stream_ordering` here that won't end up being used. + # `_update_outliers_txn()` will fix this discrepancy (always use the + # `stream_ordering` from the first time it was persisted). event.internal_metadata.stream_ordering = stream event.internal_metadata.instance_name = self._instance_name + sliding_sync_table_changes = None + if state_delta_for_room is not None: + sliding_sync_table_changes = ( + await self._calculate_sliding_sync_table_changes( + room_id, events_and_contexts, state_delta_for_room + ) + ) + await self.db_pool.runInteraction( "persist_events", self._persist_events_txn, @@ -235,6 +354,7 @@ class PersistEventsStore: state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, new_event_links=new_event_links, + sliding_sync_table_changes=sliding_sync_table_changes, ) persist_event_counter.inc(len(events_and_contexts)) @@ -261,6 +381,301 @@ class PersistEventsStore: (room_id,), frozenset(new_forward_extremities) ) + async def _calculate_sliding_sync_table_changes( + self, + room_id: str, + events_and_contexts: Sequence[Tuple[EventBase, EventContext]], + delta_state: DeltaState, + ) -> SlidingSyncTableChanges: + """ + Calculate the changes to the `sliding_sync_membership_snapshots` and + `sliding_sync_joined_rooms` tables given the deltas that are going to be used to + update the `current_state_events` table. + + Just a bunch of pre-processing so we so we don't need to spend time in the + transaction itself gathering all of this info. It's also easier to deal with + redactions outside of a transaction. + + Args: + room_id: The room ID currently being processed. + events_and_contexts: List of tuples of (event, context) being persisted. + This is completely optional (you can pass an empty list) and will just + save us from fetching the events from the database if we already have + them. We assume the list is sorted ascending by `stream_ordering`. We + don't care about the sort when the events are backfilled (with negative + `stream_ordering`). + delta_state: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. + + Returns: + SlidingSyncTableChanges + """ + to_insert = delta_state.to_insert + to_delete = delta_state.to_delete + + # If no state is changing, we don't need to do anything. This can happen when a + # partial-stated room is re-syncing the current state. + if not to_insert and not to_delete: + return SlidingSyncTableChanges( + room_id=room_id, + joined_room_bump_stamp_to_fully_insert=None, + joined_room_updates={}, + membership_snapshot_shared_insert_values={}, + to_insert_membership_snapshots=[], + to_delete_membership_snapshots=[], + ) + + event_map = {event.event_id: event for event, _ in events_and_contexts} + + # Handle gathering info for the `sliding_sync_membership_snapshots` table + # + # This would only happen if someone was state reset out of the room + user_ids_to_delete_membership_snapshots = [ + state_key + for event_type, state_key in to_delete + if event_type == EventTypes.Member and self.is_mine_id(state_key) + ] + + membership_snapshot_shared_insert_values: SlidingSyncMembershipSnapshotSharedInsertValues = {} + membership_infos_to_insert_membership_snapshots: List[ + SlidingSyncMembershipInfo + ] = [] + if to_insert: + membership_event_id_to_user_id_map: Dict[str, str] = {} + for state_key, event_id in to_insert.items(): + if state_key[0] == EventTypes.Member and self.is_mine_id(state_key[1]): + membership_event_id_to_user_id_map[event_id] = state_key[1] + + membership_event_map: Dict[str, EventBase] = {} + # In normal event persist scenarios, we should be able to find the + # membership events in the `events_and_contexts` given to us but it's + # possible a state reset happened which added us to the room without a + # corresponding new membership event (reset back to a previous membership). + missing_membership_event_ids: Set[str] = set() + for membership_event_id in membership_event_id_to_user_id_map.keys(): + membership_event = event_map.get(membership_event_id) + if membership_event: + membership_event_map[membership_event_id] = membership_event + else: + missing_membership_event_ids.add(membership_event_id) + + # Otherwise, we need to find a couple events that we were reset to. + if missing_membership_event_ids: + remaining_events = await self.store.get_events( + missing_membership_event_ids + ) + # There shouldn't be any missing events + assert remaining_events.keys() == missing_membership_event_ids, ( + missing_membership_event_ids.difference(remaining_events.keys()) + ) + membership_event_map.update(remaining_events) + + for ( + membership_event_id, + user_id, + ) in membership_event_id_to_user_id_map.items(): + membership_infos_to_insert_membership_snapshots.append( + # XXX: We don't use `SlidingSyncMembershipInfoWithEventPos` here + # because we're sourcing the event from `events_and_contexts`, we + # can't rely on `stream_ordering`/`instance_name` being correct at + # this point. We could be working with events that were previously + # persisted as an `outlier` with one `stream_ordering` but are now + # being persisted again and de-outliered and assigned a different + # `stream_ordering` that won't end up being used. Since we call + # `_calculate_sliding_sync_table_changes()` before + # `_update_outliers_txn()` which fixes this discrepancy (always use + # the `stream_ordering` from the first time it was persisted), we're + # working with an unreliable `stream_ordering` value that will + # possibly be unused and not make it into the `events` table. + SlidingSyncMembershipInfo( + user_id=user_id, + sender=membership_event_map[membership_event_id].sender, + membership_event_id=membership_event_id, + membership=membership_event_map[membership_event_id].membership, + ) + ) + + if membership_infos_to_insert_membership_snapshots: + current_state_ids_map: MutableStateMap[str] = dict( + await self.store.get_partial_filtered_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + ) + ) + # Since we fetched the current state before we took `to_insert`/`to_delete` + # into account, we need to do a couple fixups. + # + # Update the current_state_map with what we have `to_delete` + for state_key in to_delete: + current_state_ids_map.pop(state_key, None) + # Update the current_state_map with what we have `to_insert` + for state_key, event_id in to_insert.items(): + if state_key in SLIDING_SYNC_RELEVANT_STATE_SET: + current_state_ids_map[state_key] = event_id + + current_state_map: MutableStateMap[EventBase] = {} + # In normal event persist scenarios, we probably won't be able to find + # these state events in `events_and_contexts` since we don't generally + # batch up local membership changes with other events, but it can + # happen. + missing_state_event_ids: Set[str] = set() + for state_key, event_id in current_state_ids_map.items(): + event = event_map.get(event_id) + if event: + current_state_map[state_key] = event + else: + missing_state_event_ids.add(event_id) + + # Otherwise, we need to find a couple events + if missing_state_event_ids: + remaining_events = await self.store.get_events( + missing_state_event_ids + ) + # There shouldn't be any missing events + assert remaining_events.keys() == missing_state_event_ids, ( + missing_state_event_ids.difference(remaining_events.keys()) + ) + for event in remaining_events.values(): + current_state_map[(event.type, event.state_key)] = event + + if current_state_map: + state_insert_values = PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + membership_snapshot_shared_insert_values.update(state_insert_values) + # We have current state to work from + membership_snapshot_shared_insert_values["has_known_state"] = True + else: + # We don't have any `current_state_events` anymore (previously + # cleared out because of `no_longer_in_room`). This can happen if + # one user is joined and another is invited (some non-join + # membership). If the joined user leaves, we are `no_longer_in_room` + # and `current_state_events` is cleared out. When the invited user + # rejects the invite (leaves the room), we will end up here. + # + # In these cases, we should inherit the meta data from the previous + # snapshot so we shouldn't update any of the state values. When + # using sliding sync filters, this will prevent the room from + # disappearing/appearing just because you left the room. + # + # Ideally, we could additionally assert that we're only here for + # valid non-join membership transitions. + assert delta_state.no_longer_in_room + + # Handle gathering info for the `sliding_sync_joined_rooms` table + # + # We only deal with + # updating the state related columns. The + # `event_stream_ordering`/`bump_stamp` are updated elsewhere in the event + # persisting stack (see + # `_update_sliding_sync_tables_with_new_persisted_events_txn()`) + # + joined_room_updates: SlidingSyncStateInsertValues = {} + bump_stamp_to_fully_insert: Optional[int] = None + if not delta_state.no_longer_in_room: + current_state_ids_map = {} + + # Always fully-insert rows if they don't already exist in the + # `sliding_sync_joined_rooms` table. This way we can rely on a row if it + # exists in the table. + # + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + existing_row_in_table = await self.store.db_pool.simple_select_one_onecol( + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + ) + if not existing_row_in_table: + most_recent_bump_event_pos_results = ( + await self.store.get_last_event_pos_in_room( + room_id, + event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES, + ) + ) + if most_recent_bump_event_pos_results is not None: + _, new_bump_event_pos = most_recent_bump_event_pos_results + + # If we've just joined a remote room, then the last bump event may + # have been backfilled (and so have a negative stream ordering). + # These negative stream orderings can't sensibly be compared, so + # instead just leave it as `None` in the table and we will use their + # membership event position as the bump event position in the + # Sliding Sync API. + if new_bump_event_pos.stream > 0: + bump_stamp_to_fully_insert = new_bump_event_pos.stream + + current_state_ids_map = dict( + await self.store.get_partial_filtered_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + ) + ) + + # Look through the items we're going to insert into the current state to see + # if there is anything that we care about and should also update in the + # `sliding_sync_joined_rooms` table. + for state_key, event_id in to_insert.items(): + if state_key in SLIDING_SYNC_RELEVANT_STATE_SET: + current_state_ids_map[state_key] = event_id + + # Get the full event objects for the current state events + # + # In normal event persist scenarios, we should be able to find the state + # events in the `events_and_contexts` given to us but it's possible a state + # reset happened which that reset back to a previous state. + current_state_map = {} + missing_event_ids: Set[str] = set() + for state_key, event_id in current_state_ids_map.items(): + event = event_map.get(event_id) + if event: + current_state_map[state_key] = event + else: + missing_event_ids.add(event_id) + + # Otherwise, we need to find a couple events that we were reset to. + if missing_event_ids: + remaining_events = await self.store.get_events(missing_event_ids) + # There shouldn't be any missing events + assert remaining_events.keys() == missing_event_ids, ( + missing_event_ids.difference(remaining_events.keys()) + ) + for event in remaining_events.values(): + current_state_map[(event.type, event.state_key)] = event + + joined_room_updates = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + ) + + # If something is being deleted from the state, we need to clear it out + for state_key in to_delete: + if state_key == (EventTypes.Create, ""): + joined_room_updates["room_type"] = None + elif state_key == (EventTypes.RoomEncryption, ""): + joined_room_updates["is_encrypted"] = False + elif state_key == (EventTypes.Name, ""): + joined_room_updates["room_name"] = None + + return SlidingSyncTableChanges( + room_id=room_id, + # For `sliding_sync_joined_rooms` + joined_room_bump_stamp_to_fully_insert=bump_stamp_to_fully_insert, + joined_room_updates=joined_room_updates, + # For `sliding_sync_membership_snapshots` + membership_snapshot_shared_insert_values=membership_snapshot_shared_insert_values, + to_insert_membership_snapshots=membership_infos_to_insert_membership_snapshots, + to_delete_membership_snapshots=user_ids_to_delete_membership_snapshots, + ) + async def calculate_chain_cover_index_for_events( self, room_id: str, events: Collection[EventBase] ) -> Dict[str, NewEventChainLinks]: @@ -315,7 +730,7 @@ class PersistEventsStore: keyvalues={}, retcols=("event_id",), ) - already_persisted_events = {event_id for event_id, in rows} + already_persisted_events = {event_id for (event_id,) in rows} state_events = [ event for event in state_events @@ -458,6 +873,7 @@ class PersistEventsStore: state_delta_for_room: Optional[DeltaState], new_forward_extremities: Optional[Set[str]], new_event_links: Dict[str, NewEventChainLinks], + sliding_sync_table_changes: Optional[SlidingSyncTableChanges], ) -> None: """Insert some number of room events into the necessary database tables. @@ -478,9 +894,14 @@ class PersistEventsStore: delete_existing True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. - state_delta_for_room: The current-state delta for the room. + state_delta_for_room: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. new_forward_extremities: The new forward extremities for the room: a set of the event ids which are the forward extremities. + sliding_sync_table_changes: Changes to the + `sliding_sync_membership_snapshots` and `sliding_sync_joined_rooms` tables + derived from the given `delta_state` (see + `_calculate_sliding_sync_table_changes(...)`) Raises: PartialStateConflictError: if attempting to persist a partial state event in @@ -590,10 +1011,22 @@ class PersistEventsStore: # room_memberships, where applicable. # NB: This function invalidates all state related caches if state_delta_for_room: + # If the state delta exists, the sliding sync table changes should also exist + assert sliding_sync_table_changes is not None + self._update_current_state_txn( - txn, room_id, state_delta_for_room, min_stream_order + txn, + room_id, + state_delta_for_room, + min_stream_order, + sliding_sync_table_changes, ) + # We only update the sliding sync tables for non-backfilled events. + self._update_sliding_sync_tables_with_new_persisted_events_txn( + txn, room_id, events_and_contexts + ) + def _persist_event_auth_chain_txn( self, txn: LoggingTransaction, @@ -1128,8 +1561,20 @@ class PersistEventsStore: self, room_id: str, state_delta: DeltaState, + sliding_sync_table_changes: SlidingSyncTableChanges, ) -> None: - """Update the current state stored in the datatabase for the given room""" + """ + Update the current state stored in the datatabase for the given room + + Args: + room_id + state_delta: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. + sliding_sync_table_changes: Changes to the + `sliding_sync_membership_snapshots` and `sliding_sync_joined_rooms` tables + derived from the given `delta_state` (see + `_calculate_sliding_sync_table_changes(...)`) + """ if state_delta.is_noop(): return @@ -1141,6 +1586,7 @@ class PersistEventsStore: room_id, delta_state=state_delta, stream_id=stream_ordering, + sliding_sync_table_changes=sliding_sync_table_changes, ) def _update_current_state_txn( @@ -1149,16 +1595,40 @@ class PersistEventsStore: room_id: str, delta_state: DeltaState, stream_id: int, + sliding_sync_table_changes: SlidingSyncTableChanges, ) -> None: + """ + Handles updating tables that track the current state of a room. + + Args: + txn + room_id + delta_state: Deltas that are going to be used to update the + `current_state_events` table. Changes to the current state of the room. + stream_id: This is expected to be the minimum `stream_ordering` for the + batch of events that we are persisting; which means we do not end up in a + situation where workers see events before the `current_state_delta` updates. + FIXME: However, this function also gets called with next upcoming + `stream_ordering` when we re-sync the state of a partial stated room (see + `update_current_state(...)`) which may be "correct" but it would be good to + nail down what exactly is the expected value here. + sliding_sync_table_changes: Changes to the + `sliding_sync_membership_snapshots` and `sliding_sync_joined_rooms` tables + derived from the given `delta_state` (see + `_calculate_sliding_sync_table_changes(...)`) + """ to_delete = delta_state.to_delete to_insert = delta_state.to_insert + # Sanity check we're processing the same thing + assert room_id == sliding_sync_table_changes.room_id + # Figure out the changes of membership to invalidate the # `get_rooms_for_user` cache. # We find out which membership events we may have deleted # and which we have added, then we invalidate the caches for all # those users. - members_changed = { + members_to_cache_bust = { state_key for ev_type, state_key in itertools.chain(to_delete, to_insert) if ev_type == EventTypes.Member @@ -1182,16 +1652,22 @@ class PersistEventsStore: """ txn.execute(sql, (stream_id, self._instance_name, room_id)) + # Grab the list of users before we clear out the current state + users_in_room = self.store.get_users_in_room_txn(txn, room_id) # We also want to invalidate the membership caches for users # that were in the room. - users_in_room = self.store.get_users_in_room_txn(txn, room_id) - members_changed.update(users_in_room) + members_to_cache_bust.update(users_in_room) self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, ) + self.db_pool.simple_delete_txn( + txn, + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + ) else: # We're still in the room, so we update the current state as normal. @@ -1216,7 +1692,7 @@ class PersistEventsStore: """ txn.execute_batch( sql, - ( + [ ( stream_id, self._instance_name, @@ -1229,17 +1705,17 @@ class PersistEventsStore: state_key, ) for etype, state_key in itertools.chain(to_delete, to_insert) - ), + ], ) # Now we actually update the current_state_events table txn.execute_batch( "DELETE FROM current_state_events" " WHERE room_id = ? AND type = ? AND state_key = ?", - ( + [ (room_id, etype, state_key) for etype, state_key in itertools.chain(to_delete, to_insert) - ), + ], ) # We include the membership in the current state table, hence we do @@ -1260,6 +1736,63 @@ class PersistEventsStore: ], ) + # Handle updating the `sliding_sync_joined_rooms` table. We only deal with + # updating the state related columns. The + # `event_stream_ordering`/`bump_stamp` are updated elsewhere in the event + # persisting stack (see + # `_update_sliding_sync_tables_with_new_persisted_events_txn()`) + # + # We only need to update when one of the relevant state values has changed + if sliding_sync_table_changes.joined_room_updates: + sliding_sync_updates_keys = ( + sliding_sync_table_changes.joined_room_updates.keys() + ) + sliding_sync_updates_values = ( + sliding_sync_table_changes.joined_room_updates.values() + ) + + args: List[Any] = [ + room_id, + room_id, + sliding_sync_table_changes.joined_room_bump_stamp_to_fully_insert, + ] + args.extend(iter(sliding_sync_updates_values)) + + # XXX: We use a sub-query for `stream_ordering` because it's unreliable to + # pre-calculate from `events_and_contexts` at the time when + # `_calculate_sliding_sync_table_changes()` is ran. We could be working + # with events that were previously persisted as an `outlier` with one + # `stream_ordering` but are now being persisted again and de-outliered + # and assigned a different `stream_ordering`. Since we call + # `_calculate_sliding_sync_table_changes()` before + # `_update_outliers_txn()` which fixes this discrepancy (always use the + # `stream_ordering` from the first time it was persisted), we're working + # with an unreliable `stream_ordering` value that will possibly be + # unused and not make it into the `events` table. + # + # We don't update `event_stream_ordering` `ON CONFLICT` because it's + # simpler and we can just rely on + # `_update_sliding_sync_tables_with_new_persisted_events_txn()` to do + # the right thing (same for `bump_stamp`). The only reason we're + # inserting `event_stream_ordering` here is because the column has a + # `NON NULL` constraint and we need some answer. + txn.execute( + f""" + INSERT INTO sliding_sync_joined_rooms + (room_id, event_stream_ordering, bump_stamp, {", ".join(sliding_sync_updates_keys)}) + VALUES ( + ?, + (SELECT stream_ordering FROM events WHERE room_id = ? ORDER BY stream_ordering DESC LIMIT 1), + ?, + {", ".join("?" for _ in sliding_sync_updates_values)} + ) + ON CONFLICT (room_id) + DO UPDATE SET + {", ".join(f"{key} = EXCLUDED.{key}" for key in sliding_sync_updates_keys)} + """, + args, + ) + # We now update `local_current_membership`. We do this regardless # of whether we're still in the room or not to handle the case where # e.g. we just got banned (where we need to record that fact here). @@ -1272,11 +1805,11 @@ class PersistEventsStore: txn.execute_batch( "DELETE FROM local_current_membership" " WHERE room_id = ? AND user_id = ?", - ( + [ (room_id, state_key) for etype, state_key in itertools.chain(to_delete, to_insert) if etype == EventTypes.Member and self.is_mine_id(state_key) - ), + ], ) if to_insert: @@ -1296,20 +1829,422 @@ class PersistEventsStore: ], ) + # Handle updating the `sliding_sync_membership_snapshots` table + # + # This would only happen if someone was state reset out of the room + if sliding_sync_table_changes.to_delete_membership_snapshots: + self.db_pool.simple_delete_many_txn( + txn, + table="sliding_sync_membership_snapshots", + column="user_id", + values=sliding_sync_table_changes.to_delete_membership_snapshots, + keyvalues={"room_id": room_id}, + ) + + # We do this regardless of whether the server is `no_longer_in_room` or not + # because we still want a row if a local user was just left/kicked or got banned + # from the room. + if sliding_sync_table_changes.to_insert_membership_snapshots: + # Update the `sliding_sync_membership_snapshots` table + # + sliding_sync_snapshot_keys = sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys() + sliding_sync_snapshot_values = sliding_sync_table_changes.membership_snapshot_shared_insert_values.values() + # We need to insert/update regardless of whether we have + # `sliding_sync_snapshot_keys` because there are other fields in the `ON + # CONFLICT` upsert to run (see inherit case (explained in + # `_calculate_sliding_sync_table_changes()`) for more context when this + # happens). + # + # XXX: We use a sub-query for `stream_ordering` because it's unreliable to + # pre-calculate from `events_and_contexts` at the time when + # `_calculate_sliding_sync_table_changes()` is ran. We could be working with + # events that were previously persisted as an `outlier` with one + # `stream_ordering` but are now being persisted again and de-outliered and + # assigned a different `stream_ordering` that won't end up being used. Since + # we call `_calculate_sliding_sync_table_changes()` before + # `_update_outliers_txn()` which fixes this discrepancy (always use the + # `stream_ordering` from the first time it was persisted), we're working + # with an unreliable `stream_ordering` value that will possibly be unused + # and not make it into the `events` table. + txn.execute_batch( + f""" + INSERT INTO sliding_sync_membership_snapshots + (room_id, user_id, sender, membership_event_id, membership, forgotten, event_stream_ordering, event_instance_name + {("," + ", ".join(sliding_sync_snapshot_keys)) if sliding_sync_snapshot_keys else ""}) + VALUES ( + ?, ?, ?, ?, ?, ?, + (SELECT stream_ordering FROM events WHERE event_id = ?), + (SELECT COALESCE(instance_name, 'master') FROM events WHERE event_id = ?) + {("," + ", ".join("?" for _ in sliding_sync_snapshot_values)) if sliding_sync_snapshot_values else ""} + ) + ON CONFLICT (room_id, user_id) + DO UPDATE SET + sender = EXCLUDED.sender, + membership_event_id = EXCLUDED.membership_event_id, + membership = EXCLUDED.membership, + forgotten = EXCLUDED.forgotten, + event_stream_ordering = EXCLUDED.event_stream_ordering + {("," + ", ".join(f"{key} = EXCLUDED.{key}" for key in sliding_sync_snapshot_keys)) if sliding_sync_snapshot_keys else ""} + """, + [ + [ + room_id, + membership_info.user_id, + membership_info.sender, + membership_info.membership_event_id, + membership_info.membership, + # Since this is a new membership, it isn't forgotten anymore (which + # matches how Synapse currently thinks about the forgotten status) + 0, + # XXX: We do not use `membership_info.membership_event_stream_ordering` here + # because it is an unreliable value. See XXX note above. + membership_info.membership_event_id, + # XXX: We do not use `membership_info.membership_event_instance_name` here + # because it is an unreliable value. See XXX note above. + membership_info.membership_event_id, + ] + + list(sliding_sync_snapshot_values) + for membership_info in sliding_sync_table_changes.to_insert_membership_snapshots + ], + ) + txn.call_after( self.store._curr_state_delta_stream_cache.entity_has_changed, room_id, stream_id, ) + for user_id in members_to_cache_bust: + txn.call_after( + self.store._membership_stream_cache.entity_has_changed, + user_id, + stream_id, + ) + # Invalidate the various caches - self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed) + self.store._invalidate_state_caches_and_stream( + txn, room_id, members_to_cache_bust + ) # Check if any of the remote membership changes requires us to # unsubscribe from their device lists. self.store.handle_potentially_left_users_txn( - txn, {m for m in members_changed if not self.hs.is_mine_id(m)} + txn, {m for m in members_to_cache_bust if not self.hs.is_mine_id(m)} + ) + + @classmethod + def _get_relevant_sliding_sync_current_state_event_ids_txn( + cls, txn: LoggingTransaction, room_id: str + ) -> MutableStateMap[str]: + """ + Fetch the current state event IDs for the relevant (to the + `sliding_sync_joined_rooms` table) state types for the given room. + + Returns: + A tuple of: + 1. StateMap of event IDs necessary to to fetch the relevant state values + needed to insert into the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots`. + 2. The corresponding latest `stream_id` in the + `current_state_delta_stream` table. This is useful to compare against + the `current_state_delta_stream` table later so you can check whether + the current state has changed since you last fetched the current + state. + """ + # Fetch the current state event IDs from the database + ( + event_type_and_state_key_in_list_clause, + event_type_and_state_key_args, + ) = make_tuple_in_list_sql_clause( + txn.database_engine, + ("type", "state_key"), + SLIDING_SYNC_RELEVANT_STATE_SET, + ) + txn.execute( + f""" + SELECT c.event_id, c.type, c.state_key + FROM current_state_events AS c + WHERE + c.room_id = ? + AND {event_type_and_state_key_in_list_clause} + """, + [room_id] + event_type_and_state_key_args, ) + current_state_map: MutableStateMap[str] = { + (event_type, state_key): event_id for event_id, event_type, state_key in txn + } + + return current_state_map + + @classmethod + def _get_sliding_sync_insert_values_from_state_map( + cls, state_map: StateMap[EventBase] + ) -> SlidingSyncStateInsertValues: + """ + Extract the relevant state values from the `state_map` needed to insert into the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. + + Returns: + Map from column names (`room_type`, `is_encrypted`, `room_name`) to relevant + state values needed to insert into + the `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. + """ + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_insert_map: SlidingSyncStateInsertValues = {} + + # Parse the raw event JSON + for state_key, event in state_map.items(): + if state_key == (EventTypes.Create, ""): + room_type = event.content.get(EventContentFields.ROOM_TYPE) + # Scrutinize JSON values + if room_type is None or ( + isinstance(room_type, str) + # We ignore values with null bytes as Postgres doesn't allow them in + # text columns. + and "\0" not in room_type + ): + sliding_sync_insert_map["room_type"] = room_type + elif state_key == (EventTypes.RoomEncryption, ""): + encryption_algorithm = event.content.get( + EventContentFields.ENCRYPTION_ALGORITHM + ) + is_encrypted = encryption_algorithm is not None + sliding_sync_insert_map["is_encrypted"] = is_encrypted + elif state_key == (EventTypes.Name, ""): + room_name = event.content.get(EventContentFields.ROOM_NAME) + # Scrutinize JSON values. We ignore values with nulls as + # postgres doesn't allow null bytes in text columns. + if room_name is None or ( + isinstance(room_name, str) + # We ignore values with null bytes as Postgres doesn't allow them in + # text columns. + and "\0" not in room_name + ): + sliding_sync_insert_map["room_name"] = room_name + elif state_key == (EventTypes.Tombstone, ""): + successor_room_id = event.content.get( + EventContentFields.TOMBSTONE_SUCCESSOR_ROOM + ) + # Scrutinize JSON values + if successor_room_id is None or ( + isinstance(successor_room_id, str) + # We ignore values with null bytes as Postgres doesn't allow them in + # text columns. + and "\0" not in successor_room_id + ): + sliding_sync_insert_map["tombstone_successor_room_id"] = ( + successor_room_id + ) + else: + # We only expect to see events according to the + # `SLIDING_SYNC_RELEVANT_STATE_SET`. + raise AssertionError( + "Unexpected event (we should not be fetching extra events or this " + + "piece of code needs to be updated to handle a new event type added " + + "to `SLIDING_SYNC_RELEVANT_STATE_SET`): {state_key} {event.event_id}" + ) + + return sliding_sync_insert_map + + @classmethod + def _get_sliding_sync_insert_values_from_stripped_state( + cls, unsigned_stripped_state_events: Any + ) -> SlidingSyncMembershipSnapshotSharedInsertValues: + """ + Pull out the relevant state values from the stripped state on an invite or knock + membership event needed to insert into the `sliding_sync_membership_snapshots` + tables. + + Returns: + Map from column names (`room_type`, `is_encrypted`, `room_name`) to relevant + state values needed to insert into the `sliding_sync_membership_snapshots` tables. + """ + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {} + + if unsigned_stripped_state_events is not None: + stripped_state_map: MutableStateMap[StrippedStateEvent] = {} + if isinstance(unsigned_stripped_state_events, list): + for raw_stripped_event in unsigned_stripped_state_events: + stripped_state_event = parse_stripped_state_event( + raw_stripped_event + ) + if stripped_state_event is not None: + stripped_state_map[ + ( + stripped_state_event.type, + stripped_state_event.state_key, + ) + ] = stripped_state_event + + # If there is some stripped state, we assume the remote server passed *all* + # of the potential stripped state events for the room. + create_stripped_event = stripped_state_map.get((EventTypes.Create, "")) + # Sanity check that we at-least have the create event + if create_stripped_event is not None: + sliding_sync_insert_map["has_known_state"] = True + + # XXX: Keep this up-to-date with `SLIDING_SYNC_RELEVANT_STATE_SET` + + # Find the room_type + sliding_sync_insert_map["room_type"] = ( + create_stripped_event.content.get(EventContentFields.ROOM_TYPE) + if create_stripped_event is not None + else None + ) + + # Find whether the room is_encrypted + encryption_stripped_event = stripped_state_map.get( + (EventTypes.RoomEncryption, "") + ) + encryption = ( + encryption_stripped_event.content.get( + EventContentFields.ENCRYPTION_ALGORITHM + ) + if encryption_stripped_event is not None + else None + ) + sliding_sync_insert_map["is_encrypted"] = encryption is not None + + # Find the room_name + room_name_stripped_event = stripped_state_map.get((EventTypes.Name, "")) + sliding_sync_insert_map["room_name"] = ( + room_name_stripped_event.content.get(EventContentFields.ROOM_NAME) + if room_name_stripped_event is not None + else None + ) + + # Check for null bytes in the room name and type. We have to + # ignore values with null bytes as Postgres doesn't allow them + # in text columns. + if ( + sliding_sync_insert_map["room_name"] is not None + and "\0" in sliding_sync_insert_map["room_name"] + ): + sliding_sync_insert_map.pop("room_name") + + if ( + sliding_sync_insert_map["room_type"] is not None + and "\0" in sliding_sync_insert_map["room_type"] + ): + sliding_sync_insert_map.pop("room_type") + + # Find the tombstone_successor_room_id + # Note: This isn't one of the stripped state events according to the spec + # but seems like there is no reason not to support this kind of thing. + tombstone_stripped_event = stripped_state_map.get( + (EventTypes.Tombstone, "") + ) + sliding_sync_insert_map["tombstone_successor_room_id"] = ( + tombstone_stripped_event.content.get( + EventContentFields.TOMBSTONE_SUCCESSOR_ROOM + ) + if tombstone_stripped_event is not None + else None + ) + + if ( + sliding_sync_insert_map["tombstone_successor_room_id"] is not None + and "\0" in sliding_sync_insert_map["tombstone_successor_room_id"] + ): + sliding_sync_insert_map.pop("tombstone_successor_room_id") + + else: + # No stripped state provided + sliding_sync_insert_map["has_known_state"] = False + sliding_sync_insert_map["room_type"] = None + sliding_sync_insert_map["room_name"] = None + sliding_sync_insert_map["is_encrypted"] = False + else: + # No stripped state provided + sliding_sync_insert_map["has_known_state"] = False + sliding_sync_insert_map["room_type"] = None + sliding_sync_insert_map["room_name"] = None + sliding_sync_insert_map["is_encrypted"] = False + + return sliding_sync_insert_map + + def _update_sliding_sync_tables_with_new_persisted_events_txn( + self, + txn: LoggingTransaction, + room_id: str, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: + """ + Update the latest `event_stream_ordering`/`bump_stamp` columns in the + `sliding_sync_joined_rooms` table for the room with new events. + + This function assumes that `_store_event_txn()` (to persist the event) and + `_update_current_state_txn(...)` (so that `sliding_sync_joined_rooms` table has + been updated with rooms that were joined) have already been run. + + Args: + txn + room_id: The room that all of the events belong to + events_and_contexts: The events being persisted. We assume the list is + sorted ascending by `stream_ordering`. We don't care about the sort when the + events are backfilled (with negative `stream_ordering`). + """ + + # Nothing to do if there are no events + if len(events_and_contexts) == 0: + return + + # Since the list is sorted ascending by `stream_ordering`, the last event should + # have the highest `stream_ordering`. + max_stream_ordering = events_and_contexts[-1][ + 0 + ].internal_metadata.stream_ordering + # `stream_ordering` should be assigned for persisted events + assert max_stream_ordering is not None + # Check if the event is a backfilled event (with a negative `stream_ordering`). + # If one event is backfilled, we assume this whole batch was backfilled. + if max_stream_ordering < 0: + # We only update the sliding sync tables for non-backfilled events. + return + + max_bump_stamp = None + for event, _ in reversed(events_and_contexts): + # Sanity check that all events belong to the same room + assert event.room_id == room_id + + if event.type in SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES: + # `stream_ordering` should be assigned for persisted events + assert event.internal_metadata.stream_ordering is not None + + max_bump_stamp = event.internal_metadata.stream_ordering + + # Since we're iterating in reverse, we can break as soon as we find a + # matching bump event which should have the highest `stream_ordering`. + break + + # Handle updating the `sliding_sync_joined_rooms` table. + # + txn.execute( + """ + UPDATE sliding_sync_joined_rooms + SET + event_stream_ordering = CASE + WHEN event_stream_ordering IS NULL OR event_stream_ordering < ? + THEN ? + ELSE event_stream_ordering + END, + bump_stamp = CASE + WHEN bump_stamp IS NULL OR bump_stamp < ? + THEN ? + ELSE bump_stamp + END + WHERE room_id = ? + """, + ( + max_stream_ordering, + max_stream_ordering, + max_bump_stamp, + max_bump_stamp, + room_id, + ), + ) + # This may or may not update any rows depending if we are `no_longer_in_room` def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state @@ -1931,7 +2866,9 @@ class PersistEventsStore: ) for event in events: + # Sanity check that we're working with persisted events assert event.internal_metadata.stream_ordering is not None + assert event.internal_metadata.instance_name is not None # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. @@ -1945,6 +2882,16 @@ class PersistEventsStore: and event.internal_metadata.is_outlier() and event.internal_metadata.is_out_of_band_membership() ): + # The only sort of out-of-band-membership events we expect to see here + # are remote invites/knocks and LEAVE events corresponding to + # rejected/retracted invites and rescinded knocks. + assert event.type == EventTypes.Member + assert event.membership in ( + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + ) + self.db_pool.simple_upsert_txn( txn, table="local_current_membership", @@ -1956,6 +2903,59 @@ class PersistEventsStore: }, ) + # Handle updating the `sliding_sync_membership_snapshots` table + # (out-of-band membership events only) + # + raw_stripped_state_events = None + if event.membership == Membership.INVITE: + invite_room_state = event.unsigned.get("invite_room_state") + raw_stripped_state_events = invite_room_state + elif event.membership == Membership.KNOCK: + knock_room_state = event.unsigned.get("knock_room_state") + raw_stripped_state_events = knock_room_state + + insert_values = { + "sender": event.sender, + "membership_event_id": event.event_id, + "membership": event.membership, + # Since this is a new membership, it isn't forgotten anymore (which + # matches how Synapse currently thinks about the forgotten status) + "forgotten": 0, + "event_stream_ordering": event.internal_metadata.stream_ordering, + "event_instance_name": event.internal_metadata.instance_name, + } + if event.membership == Membership.LEAVE: + # Inherit the meta data from the remote invite/knock. When using + # sliding sync filters, this will prevent the room from + # disappearing/appearing just because you left the room. + pass + elif event.membership in (Membership.INVITE, Membership.KNOCK): + extra_insert_values = ( + self._get_sliding_sync_insert_values_from_stripped_state( + raw_stripped_state_events + ) + ) + insert_values.update(extra_insert_values) + else: + # We don't know how to handle this type of membership yet + # + # FIXME: We should use `assert_never` here but for some reason + # the exhaustive matching doesn't recognize the `Never` here. + # assert_never(event.membership) + raise AssertionError( + f"Unexpected out-of-band membership {event.membership} ({event.event_id}) that we don't know how to handle yet" + ) + + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={ + "room_id": event.room_id, + "user_id": event.state_key, + }, + values=insert_values, + ) + def _handle_event_relations( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -2221,7 +3221,7 @@ class PersistEventsStore: if notifiable_events: txn.execute_batch( sql, - ( + [ ( event.room_id, event.internal_metadata.stream_ordering, @@ -2229,18 +3229,18 @@ class PersistEventsStore: event.event_id, ) for event in notifiable_events - ), + ], ) # Now we delete the staging area for *all* events that were being # persisted. txn.execute_batch( "DELETE FROM event_push_actions_staging WHERE event_id = ?", - ( + [ (event.event_id,) for event, _ in all_events_and_contexts if event.internal_metadata.is_notifiable() - ), + ], ) def _remove_push_actions_for_event_id_txn( @@ -2415,7 +3415,7 @@ class PersistEventsStore: ) potential_backwards_extremities.difference_update( - e for e, in existing_events_outliers + e for (e,) in existing_events_outliers ) if potential_backwards_extremities: @@ -2448,8 +3448,7 @@ class PersistEventsStore: # Delete all these events that we've already fetched and now know that their # prev events are the new backwards extremeties. query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + "DELETE FROM event_backward_extremities WHERE event_id = ? AND room_id = ?" ) backward_extremity_tuples_to_remove = [ (ev.event_id, ev.room_id) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 64d303e330..5c83a9f779 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -24,9 +24,14 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import attr -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import ( + MAX_DEPTH, + EventContentFields, + Membership, + RelationTypes, +) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -34,9 +39,27 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) -from synapse.storage.databases.main.events import PersistEventsStore +from synapse.storage.databases.main.events import ( + SLIDING_SYNC_RELEVANT_STATE_SET, + PersistEventsStore, + SlidingSyncMembershipInfoWithEventPos, + SlidingSyncMembershipSnapshotSharedInsertValues, + SlidingSyncStateInsertValues, +) +from synapse.storage.databases.main.events_worker import ( + DatabaseCorruptionError, + InvalidEventError, +) +from synapse.storage.databases.main.state_deltas import StateDeltasStore +from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor -from synapse.types import JsonDict, StrCollection +from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection +from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES +from synapse.types.state import StateFilter +from synapse.types.storage import _BackgroundUpdates +from synapse.util import json_encoder +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -59,26 +82,6 @@ _REPLACE_STREAM_ORDERING_SQL_COMMANDS = ( ) -class _BackgroundUpdates: - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2" - INDEX_STREAM_ORDERING2 = "index_stream_ordering2" - INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url" - INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order" - INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream" - INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts" - REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" - - EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows" - EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index" - - EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" - - EVENTS_JUMP_TO_DATE_INDEX = "events_jump_to_date_index" - - @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: """Return value for _calculate_chain_cover_txn.""" @@ -97,7 +100,19 @@ class _CalculateChainCover: finished_room_map: Dict[str, Tuple[int, int]] -class EventsBackgroundUpdatesStore(SQLBaseStore): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _JoinedRoomStreamOrderingUpdate: + """ + Intermediate container class used in `SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE` + """ + + # The most recent event stream_ordering for the room + most_recent_event_stream_ordering: int + # The most recent event `bump_stamp` for the room + most_recent_bump_stamp: Optional[int] + + +class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -279,6 +294,44 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT outlier", ) + # Handle background updates for Sliding Sync tables + # + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE, + self._sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update, + ) + # Add some background updates to populate the sliding sync tables + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE, + self._sliding_sync_joined_rooms_bg_update, + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + self._sliding_sync_membership_snapshots_bg_update, + ) + # Add a background update to fix data integrity issue in the + # `sliding_sync_membership_snapshots` -> `forgotten` column + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE, + self._sliding_sync_membership_snapshots_fix_forgotten_column_bg_update, + ) + + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.FIXUP_MAX_DEPTH_CAP, self.fixup_max_depth_cap_bg_update + ) + + # We want this to run on the main database at startup before we start processing + # events. + # + # FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + # https://github.com/element-hq/synapse/issues/17623) + with db_conn.cursor(txn_name="resolve_sliding_sync") as txn: + _resolve_stale_data_in_sliding_sync_tables( + txn=txn, + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: @@ -586,7 +639,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): room_ids = {row[0] for row in rows} for room_id in room_ids: txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) # type: ignore[attr-defined] + self.get_latest_event_ids_in_room.invalidate, # type: ignore[attr-defined] + (room_id,), ) self.db_pool.simple_delete_many_txn( @@ -1073,7 +1127,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, - self.event_chain_id_gen, # type: ignore[attr-defined] + self.event_chain_id_gen, event_to_room_id, event_to_types, cast(Dict[str, StrCollection], event_to_auth_chain), @@ -1516,3 +1570,1320 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return batch_size + + async def _sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update( + self, progress: JsonDict, _batch_size: int + ) -> int: + """ + Prefill `sliding_sync_joined_rooms_to_recalculate` table with all rooms we know about already. + """ + + def _txn(txn: LoggingTransaction) -> None: + # We do this as one big bulk insert. This has been tested on a bigger + # homeserver with ~10M rooms and took 60s. There is potential for this to + # starve disk usage while this goes on. + # + # We upsert in case we have to run this multiple times. + txn.execute( + """ + INSERT INTO sliding_sync_joined_rooms_to_recalculate + (room_id) + SELECT DISTINCT room_id FROM local_current_membership + WHERE membership = 'join' + ON CONFLICT (room_id) + DO NOTHING; + """, + ) + + await self.db_pool.runInteraction( + "_sliding_sync_prefill_joined_rooms_to_recalculate_table_bg_update", + _txn, + ) + + # Background update is done. + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE + ) + return 0 + + async def _sliding_sync_joined_rooms_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Background update to populate the `sliding_sync_joined_rooms` table. + """ + # We don't need to fetch any progress state because we just grab the next N + # events in `sliding_sync_joined_rooms_to_recalculate` + + def _get_rooms_to_update_txn(txn: LoggingTransaction) -> List[Tuple[str]]: + """ + Returns: + A list of room ID's to update along with the progress value + (event_stream_ordering) indicating the continuation point in the + `current_state_events` table for the next batch. + """ + # Fetch the set of room IDs that we want to update + # + # We use `current_state_events` table as the barometer for whether the + # server is still participating in the room because if we're + # `no_longer_in_room`, this table would be cleared out for the given + # `room_id`. + txn.execute( + """ + SELECT room_id + FROM sliding_sync_joined_rooms_to_recalculate + LIMIT ? + """, + (batch_size,), + ) + + rooms_to_update_rows = cast(List[Tuple[str]], txn.fetchall()) + + return rooms_to_update_rows + + rooms_to_update = await self.db_pool.runInteraction( + "_sliding_sync_joined_rooms_bg_update._get_rooms_to_update_txn", + _get_rooms_to_update_txn, + ) + + if not rooms_to_update: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE + ) + return 0 + + # Map from room_id to insert/update state values in the `sliding_sync_joined_rooms` table. + joined_room_updates: Dict[str, SlidingSyncStateInsertValues] = {} + # Map from room_id to stream_ordering/bump_stamp, etc values + joined_room_stream_ordering_updates: Dict[ + str, _JoinedRoomStreamOrderingUpdate + ] = {} + # As long as we get this value before we fetch the current state, we can use it + # to check if something has changed since that point. + most_recent_current_state_delta_stream_id = ( + await self.get_max_stream_id_in_current_state_deltas() + ) + for (room_id,) in rooms_to_update: + current_state_ids_map = await self.db_pool.runInteraction( + "_sliding_sync_joined_rooms_bg_update._get_relevant_sliding_sync_current_state_event_ids_txn", + PersistEventsStore._get_relevant_sliding_sync_current_state_event_ids_txn, + room_id, + ) + + # If we're not joined to the room a) it doesn't belong in the + # `sliding_sync_joined_rooms` table so we should skip and b) we won't have + # any `current_state_events` for the room. + if not current_state_ids_map: + continue + + try: + fetched_events = await self.get_events(current_state_ids_map.values()) + except (DatabaseCorruptionError, InvalidEventError) as e: + logger.warning( + "Failed to fetch state for room '%s' due to corrupted events. Ignoring. Error: %s", + room_id, + e, + ) + continue + + current_state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in current_state_ids_map.items() + # `get_events(...)` will filter out events for unknown room versions + if event_id in fetched_events + } + + # Even if we are joined to the room, this can happen for unknown room + # versions (old room versions that aren't known anymore) since + # `get_events(...)` will filter out events for unknown room versions + if not current_state_map: + continue + + state_insert_values = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + ) + # We should have some insert values for each room, even if they are `None` + assert state_insert_values + joined_room_updates[room_id] = state_insert_values + + # Figure out the stream_ordering of the latest event in the room + most_recent_event_pos_results = await self.get_last_event_pos_in_room( + room_id, event_types=None + ) + assert most_recent_event_pos_results is not None, ( + f"We should not be seeing `None` here because the room ({room_id}) should at-least have a create event " + + "given we pulled the room out of `current_state_events`" + ) + most_recent_event_stream_ordering = most_recent_event_pos_results[1].stream + + # The `most_recent_event_stream_ordering` should be positive, + # however there are (very rare) rooms where that is not the case in + # the matrix.org database. It's not clear how they got into that + # state, but does mean that we cannot assert that the stream + # ordering is indeed positive. + + # Figure out the latest `bump_stamp` in the room. This could be `None` for a + # federated room you just joined where all of events are still `outliers` or + # backfilled history. In the Sliding Sync API, we default to the user's + # membership event `stream_ordering` if we don't have a `bump_stamp` so + # having it as `None` in this table is fine. + bump_stamp_event_pos_results = await self.get_last_event_pos_in_room( + room_id, event_types=SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES + ) + most_recent_bump_stamp = None + if ( + bump_stamp_event_pos_results is not None + and bump_stamp_event_pos_results[1].stream > 0 + ): + most_recent_bump_stamp = bump_stamp_event_pos_results[1].stream + + joined_room_stream_ordering_updates[room_id] = ( + _JoinedRoomStreamOrderingUpdate( + most_recent_event_stream_ordering=most_recent_event_stream_ordering, + most_recent_bump_stamp=most_recent_bump_stamp, + ) + ) + + def _fill_table_txn(txn: LoggingTransaction) -> None: + # Handle updating the `sliding_sync_joined_rooms` table + # + for ( + room_id, + update_map, + ) in joined_room_updates.items(): + joined_room_stream_ordering_update = ( + joined_room_stream_ordering_updates[room_id] + ) + event_stream_ordering = ( + joined_room_stream_ordering_update.most_recent_event_stream_ordering + ) + bump_stamp = joined_room_stream_ordering_update.most_recent_bump_stamp + + # Check if the current state has been updated since we gathered it. + # We're being careful not to insert/overwrite with stale data. + state_deltas_since_we_gathered_current_state = ( + self.get_current_state_deltas_for_room_txn( + txn, + room_id, + from_token=RoomStreamToken( + stream=most_recent_current_state_delta_stream_id + ), + to_token=None, + ) + ) + for state_delta in state_deltas_since_we_gathered_current_state: + # We only need to check for the state is relevant to the + # `sliding_sync_joined_rooms` table. + if ( + state_delta.event_type, + state_delta.state_key, + ) in SLIDING_SYNC_RELEVANT_STATE_SET: + # Raising exception so we can just exit and try again. It would + # be hard to resolve this within the transaction because we need + # to get full events out that take redactions into account. We + # could add some retry logic here, but it's easier to just let + # the background update try again. + raise Exception( + "Current state was updated after we gathered it to update " + + "`sliding_sync_joined_rooms` in the background update. " + + "Raising exception so we can just try again." + ) + + # Since we fully insert rows into `sliding_sync_joined_rooms`, we can + # just do everything on insert and `ON CONFLICT DO NOTHING`. + # + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={ + **update_map, + # The reason we're only *inserting* (not *updating*) `event_stream_ordering` + # and `bump_stamp` is because if they are present, that means they are already + # up-to-date. + "event_stream_ordering": event_stream_ordering, + "bump_stamp": bump_stamp, + }, + ) + + # Now that we've processed all the room, we can remove them from the + # queue. + # + # Note: we need to remove all the rooms from the queue we pulled out + # from the DB, not just the ones we've processed above. Otherwise + # we'll simply keep pulling out the same rooms over and over again. + self.db_pool.simple_delete_many_batch_txn( + txn, + table="sliding_sync_joined_rooms_to_recalculate", + keys=("room_id",), + values=rooms_to_update, + ) + + await self.db_pool.runInteraction( + "sliding_sync_joined_rooms_bg_update", _fill_table_txn + ) + + return len(rooms_to_update) + + async def _sliding_sync_membership_snapshots_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Background update to populate the `sliding_sync_membership_snapshots` table. + """ + # We do this in two phases: a) the initial phase where we go through all + # room memberships, and then b) a second phase where we look at new + # memberships (this is to handle the case where we downgrade and then + # upgrade again). + # + # We have to do this as two phases (rather than just the second phase + # where we iterate on event_stream_ordering), as the + # `event_stream_ordering` column may have null values for old rows. + # Therefore we first do the set of historic rooms and *then* look at any + # new rows (which will have a non-null `event_stream_ordering`). + initial_phase = progress.get("initial_phase") + if initial_phase is None: + # If this is the first run, store the current max stream position. + # We know we will go through all memberships less than the current + # max in the initial phase. + progress = { + "initial_phase": True, + "last_event_stream_ordering": self.get_room_max_stream_ordering(), + } + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + progress, + ) + initial_phase = True + + last_room_id = progress.get("last_room_id", "") + last_user_id = progress.get("last_user_id", "") + last_event_stream_ordering = progress["last_event_stream_ordering"] + + def _find_memberships_to_update_txn( + txn: LoggingTransaction, + ) -> List[ + Tuple[ + str, + Optional[str], + Optional[str], + str, + str, + str, + str, + int, + Optional[str], + bool, + ] + ]: + # Fetch the set of event IDs that we want to update + # + # We skip over rows which we've already handled, i.e. have a + # matching row in `sliding_sync_membership_snapshots` with the same + # room, user and event ID. + # + # We also ignore rooms that the user has left themselves (i.e. not + # kicked). This is to avoid having to port lots of old rooms that we + # will never send down sliding sync (as we exclude such rooms from + # initial syncs). + + if initial_phase: + # There are some old out-of-band memberships (before + # https://github.com/matrix-org/synapse/issues/6983) where we don't have + # the corresponding room stored in the `rooms` table`. We use `LEFT JOIN + # rooms AS r USING (room_id)` to find the rooms missing from `rooms` and + # insert a row for them below. + txn.execute( + """ + SELECT + c.room_id, + r.room_id, + r.room_version, + c.user_id, + e.sender, + c.event_id, + c.membership, + e.stream_ordering, + e.instance_name, + e.outlier + FROM local_current_membership AS c + LEFT JOIN sliding_sync_membership_snapshots AS m USING (room_id, user_id) + INNER JOIN events AS e USING (event_id) + LEFT JOIN rooms AS r ON (c.room_id = r.room_id) + WHERE (c.room_id, c.user_id) > (?, ?) + AND (m.user_id IS NULL OR c.event_id != m.membership_event_id) + ORDER BY c.room_id ASC, c.user_id ASC + LIMIT ? + """, + (last_room_id, last_user_id, batch_size), + ) + elif last_event_stream_ordering is not None: + # It's important to sort by `event_stream_ordering` *ascending* (oldest to + # newest) so that if we see that this background update in progress and want + # to start the catch-up process, we can safely assume that it will + # eventually get to the rooms we want to catch-up on anyway (see + # `_resolve_stale_data_in_sliding_sync_tables()`). + # + # `c.room_id` is duplicated to make it match what we're doing in the + # `initial_phase`. But we can avoid doing the extra `rooms` table join + # because we can assume all of these new events won't have this problem. + txn.execute( + """ + SELECT + c.room_id, + r.room_id, + r.room_version, + c.user_id, + e.sender, + c.event_id, + c.membership, + c.event_stream_ordering, + e.instance_name, + e.outlier + FROM local_current_membership AS c + LEFT JOIN sliding_sync_membership_snapshots AS m USING (room_id, user_id) + INNER JOIN events AS e USING (event_id) + LEFT JOIN rooms AS r ON (c.room_id = r.room_id) + WHERE c.event_stream_ordering > ? + AND (m.user_id IS NULL OR c.event_id != m.membership_event_id) + ORDER BY c.event_stream_ordering ASC + LIMIT ? + """, + (last_event_stream_ordering, batch_size), + ) + else: + raise Exception("last_event_stream_ordering should not be None") + + memberships_to_update_rows = cast( + List[ + Tuple[ + str, + Optional[str], + Optional[str], + str, + str, + str, + str, + int, + Optional[str], + bool, + ] + ], + txn.fetchall(), + ) + + return memberships_to_update_rows + + memberships_to_update_rows = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update._find_memberships_to_update_txn", + _find_memberships_to_update_txn, + ) + + if not memberships_to_update_rows: + if initial_phase: + # Move onto the next phase. + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + { + "initial_phase": False, + "last_event_stream_ordering": last_event_stream_ordering, + }, + ) + return 0 + else: + # We've finished both phases, we're done. + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE + ) + return 0 + + def _find_previous_invite_or_knock_membership_txn( + txn: LoggingTransaction, room_id: str, user_id: str, event_id: str + ) -> Optional[Tuple[str, str]]: + # Find the previous invite/knock event before the leave event + # + # Here are some notes on how we landed on this query: + # + # We're using `topological_ordering` instead of `stream_ordering` because + # somehow it's possible to have your `leave` event backfilled with a + # negative `stream_ordering` and your previous `invite` event with a + # positive `stream_ordering` so we wouldn't have a chance of finding the + # previous membership with a naive `event_stream_ordering < ?` comparison. + # + # Also be careful because `room_memberships.event_stream_ordering` is + # nullable and not always filled in. You would need to join on `events` to + # rely on `events.stream_ordering` instead. Even though the + # `events.stream_ordering` also doesn't have a `NOT NULL` constraint, it + # doesn't have any rows where this is the case (checked on `matrix.org`). + # The fact the `events.stream_ordering` is a nullable column is a holdover + # from a rename of the column. + # + # You might also consider using the `event_auth` table to find the previous + # membership, but there are cases where somehow a membership event doesn't + # point back to the previous membership event in the auth events (unknown + # cause). + txn.execute( + """ + SELECT event_id, membership + FROM room_memberships AS m + INNER JOIN events AS e USING (room_id, event_id) + WHERE + room_id = ? + AND m.user_id = ? + AND (m.membership = ? OR m.membership = ?) + AND e.event_id != ? + ORDER BY e.topological_ordering DESC + LIMIT 1 + """, + ( + room_id, + user_id, + # We look explicitly for `invite` and `knock` events instead of + # just their previous membership as someone could have been `invite` + # -> `ban` -> unbanned (`leave`) and we want to find the `invite` + # event where the stripped state is. + Membership.INVITE, + Membership.KNOCK, + event_id, + ), + ) + row = txn.fetchone() + + if row is None: + # Generally we should have an invite or knock event for leaves + # that are outliers, however this may not always be the case + # (e.g. a local user got kicked but the kick event got pulled in + # as an outlier). + return None + + event_id, membership = row + + return event_id, membership + + # Map from (room_id, user_id) to ... + to_insert_membership_snapshots: Dict[ + Tuple[str, str], SlidingSyncMembershipSnapshotSharedInsertValues + ] = {} + to_insert_membership_infos: Dict[ + Tuple[str, str], SlidingSyncMembershipInfoWithEventPos + ] = {} + for ( + room_id, + room_id_from_rooms_table, + room_version_id, + user_id, + sender, + membership_event_id, + membership, + membership_event_stream_ordering, + membership_event_instance_name, + is_outlier, + ) in memberships_to_update_rows: + # We don't know how to handle `membership` values other than these. The + # code below would need to be updated. + assert membership in ( + Membership.JOIN, + Membership.INVITE, + Membership.KNOCK, + Membership.LEAVE, + Membership.BAN, + ) + + if ( + room_version_id is not None + and room_version_id not in KNOWN_ROOM_VERSIONS + ): + # Ignore rooms with unknown room versions (these were + # experimental rooms, that we no longer support). + continue + + # There are some old out-of-band memberships (before + # https://github.com/matrix-org/synapse/issues/6983) where we don't have the + # corresponding room stored in the `rooms` table`. We have a `FOREIGN KEY` + # constraint on the `sliding_sync_membership_snapshots` table so we have to + # fix-up these memberships by adding the room to the `rooms` table. + if room_id_from_rooms_table is None: + await self.db_pool.simple_insert( + table="rooms", + values={ + "room_id": room_id, + # Only out-of-band memberships are missing from the `rooms` + # table so that is the only type of membership we're dealing + # with here. Since we don't calculate the "chain cover" for + # out-of-band memberships, we can just set this to `True` as if + # the user ever joins the room, we will end up calculating the + # "chain cover" anyway. + "has_auth_chain_index": True, + }, + ) + + # Map of values to insert/update in the `sliding_sync_membership_snapshots` table + sliding_sync_membership_snapshots_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {} + if membership == Membership.JOIN: + # If we're still joined, we can pull from current state. + current_state_ids_map: StateMap[ + str + ] = await self.hs.get_storage_controllers().state.get_current_state_ids( + room_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want some non-membership state + await_full_state=False, + ) + # We're iterating over rooms that we are joined to so they should + # have `current_state_events` and we should have some current state + # for each room + if current_state_ids_map: + try: + fetched_events = await self.get_events( + current_state_ids_map.values() + ) + except (DatabaseCorruptionError, InvalidEventError) as e: + logger.warning( + "Failed to fetch state for room '%s' due to corrupted events. Ignoring. Error: %s", + room_id, + e, + ) + continue + + current_state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in current_state_ids_map.items() + # `get_events(...)` will filter out events for unknown room versions + if event_id in fetched_events + } + + # Can happen for unknown room versions (old room versions that aren't known + # anymore) since `get_events(...)` will filter out events for unknown room + # versions + if not current_state_map: + continue + + state_insert_values = PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + current_state_map + ) + sliding_sync_membership_snapshots_insert_map.update( + state_insert_values + ) + # We should have some insert values for each room, even if they are `None` + assert sliding_sync_membership_snapshots_insert_map + + # We have current state to work from + sliding_sync_membership_snapshots_insert_map["has_known_state"] = ( + True + ) + else: + # Although we expect every room to have a create event (even + # past unknown room versions since we haven't supported one + # without it), there seem to be some corrupted rooms in + # practice that don't have the create event in the + # `current_state_events` table. The create event does exist + # in the events table though. We'll just say that we don't + # know the state for these rooms and continue on with our + # day. + sliding_sync_membership_snapshots_insert_map = { + "has_known_state": False, + "room_type": None, + "room_name": None, + "is_encrypted": False, + } + elif membership in (Membership.INVITE, Membership.KNOCK) or ( + membership in (Membership.LEAVE, Membership.BAN) and is_outlier + ): + invite_or_knock_event_id = None + invite_or_knock_membership = None + + # If the event is an `out_of_band_membership` (special case of + # `outlier`), we never had historical state so we have to pull from + # the stripped state on the previous invite/knock event. This gives + # us a consistent view of the room state regardless of your + # membership (i.e. the room shouldn't disappear if your using the + # `is_encrypted` filter and you leave). + if membership in (Membership.LEAVE, Membership.BAN) and is_outlier: + previous_membership = await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update._find_previous_invite_or_knock_membership_txn", + _find_previous_invite_or_knock_membership_txn, + room_id, + user_id, + membership_event_id, + ) + if previous_membership is not None: + ( + invite_or_knock_event_id, + invite_or_knock_membership, + ) = previous_membership + else: + invite_or_knock_event_id = membership_event_id + invite_or_knock_membership = membership + + if ( + invite_or_knock_event_id is not None + and invite_or_knock_membership is not None + ): + # Pull from the stripped state on the invite/knock event + invite_or_knock_event = await self.get_event( + invite_or_knock_event_id + ) + + raw_stripped_state_events = None + if invite_or_knock_membership == Membership.INVITE: + invite_room_state = invite_or_knock_event.unsigned.get( + "invite_room_state" + ) + raw_stripped_state_events = invite_room_state + elif invite_or_knock_membership == Membership.KNOCK: + knock_room_state = invite_or_knock_event.unsigned.get( + "knock_room_state" + ) + raw_stripped_state_events = knock_room_state + + sliding_sync_membership_snapshots_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state( + raw_stripped_state_events + ) + else: + # We couldn't find any state for the membership, so we just have to + # leave it as empty. + sliding_sync_membership_snapshots_insert_map = { + "has_known_state": False, + "room_type": None, + "room_name": None, + "is_encrypted": False, + } + + # We should have some insert values for each room, even if no + # stripped state is on the event because we still want to record + # that we have no known state + assert sliding_sync_membership_snapshots_insert_map + elif membership in (Membership.LEAVE, Membership.BAN): + # Pull from historical state + state_ids_map = await self.hs.get_storage_controllers().state.get_state_ids_for_event( + membership_event_id, + state_filter=StateFilter.from_types( + SLIDING_SYNC_RELEVANT_STATE_SET + ), + # Partially-stated rooms should have all state events except for + # remote membership events so we don't need to wait at all because + # we only want some non-membership state + await_full_state=False, + ) + + try: + fetched_events = await self.get_events(state_ids_map.values()) + except (DatabaseCorruptionError, InvalidEventError) as e: + logger.warning( + "Failed to fetch state for room '%s' due to corrupted events. Ignoring. Error: %s", + room_id, + e, + ) + continue + + state_map: StateMap[EventBase] = { + state_key: fetched_events[event_id] + for state_key, event_id in state_ids_map.items() + # `get_events(...)` will filter out events for unknown room versions + if event_id in fetched_events + } + + # Can happen for unknown room versions (old room versions that aren't known + # anymore) since `get_events(...)` will filter out events for unknown room + # versions + if not state_map: + continue + + state_insert_values = ( + PersistEventsStore._get_sliding_sync_insert_values_from_state_map( + state_map + ) + ) + sliding_sync_membership_snapshots_insert_map.update(state_insert_values) + # We should have some insert values for each room, even if they are `None` + assert sliding_sync_membership_snapshots_insert_map + + # We have historical state to work from + sliding_sync_membership_snapshots_insert_map["has_known_state"] = True + else: + # We don't know how to handle this type of membership yet + # + # FIXME: We should use `assert_never` here but for some reason + # the exhaustive matching doesn't recognize the `Never` here. + # assert_never(membership) + raise AssertionError( + f"Unexpected membership {membership} ({membership_event_id}) that we don't know how to handle yet" + ) + + to_insert_membership_snapshots[(room_id, user_id)] = ( + sliding_sync_membership_snapshots_insert_map + ) + to_insert_membership_infos[(room_id, user_id)] = ( + SlidingSyncMembershipInfoWithEventPos( + user_id=user_id, + sender=sender, + membership_event_id=membership_event_id, + membership=membership, + membership_event_stream_ordering=membership_event_stream_ordering, + # If instance_name is null we default to "master" + membership_event_instance_name=membership_event_instance_name + or "master", + ) + ) + + def _fill_table_txn(txn: LoggingTransaction) -> None: + # Handle updating the `sliding_sync_membership_snapshots` table + # + for key, insert_map in to_insert_membership_snapshots.items(): + room_id, user_id = key + membership_info = to_insert_membership_infos[key] + sender = membership_info.sender + membership_event_id = membership_info.membership_event_id + membership = membership_info.membership + membership_event_stream_ordering = ( + membership_info.membership_event_stream_ordering + ) + membership_event_instance_name = ( + membership_info.membership_event_instance_name + ) + + # We don't need to upsert the state because we never partially + # insert/update the snapshots and anything already there is up-to-date + # EXCEPT for the `forgotten` field since that is updated out-of-band + # from the membership changes. + # + # Even though we're only doing insertions, we're using + # `simple_upsert_txn()` here to avoid unique violation errors that would + # happen from `simple_insert_txn()` + self.db_pool.simple_upsert_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={"room_id": room_id, "user_id": user_id}, + values={}, + insertion_values={ + **insert_map, + "sender": sender, + "membership_event_id": membership_event_id, + "membership": membership, + "event_stream_ordering": membership_event_stream_ordering, + "event_instance_name": membership_event_instance_name, + }, + ) + # We need to find the `forgotten` value during the transaction because + # we can't risk inserting stale data. + if isinstance(txn.database_engine, PostgresEngine): + txn.execute( + """ + UPDATE sliding_sync_membership_snapshots + SET + forgotten = m.forgotten + FROM room_memberships AS m + WHERE sliding_sync_membership_snapshots.room_id = ? + AND sliding_sync_membership_snapshots.user_id = ? + AND membership_event_id = ? + AND membership_event_id = m.event_id + AND m.event_id IS NOT NULL + """, + ( + room_id, + user_id, + membership_event_id, + ), + ) + else: + # SQLite doesn't support UPDATE FROM before 3.33.0, so we do + # this via sub-selects. + txn.execute( + """ + UPDATE sliding_sync_membership_snapshots + SET + forgotten = (SELECT forgotten FROM room_memberships WHERE event_id = ?) + WHERE room_id = ? and user_id = ? AND membership_event_id = ? + """, + ( + membership_event_id, + room_id, + user_id, + membership_event_id, + ), + ) + + await self.db_pool.runInteraction( + "sliding_sync_membership_snapshots_bg_update", _fill_table_txn + ) + + # Update the progress + ( + room_id, + _room_id_from_rooms_table, + _room_version_id, + user_id, + _sender, + _membership_event_id, + _membership, + membership_event_stream_ordering, + _membership_event_instance_name, + _is_outlier, + ) = memberships_to_update_rows[-1] + + progress = { + "initial_phase": initial_phase, + "last_room_id": room_id, + "last_user_id": user_id, + "last_event_stream_ordering": last_event_stream_ordering, + } + if not initial_phase: + progress["last_event_stream_ordering"] = membership_event_stream_ordering + + await self.db_pool.updates._background_update_progress( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + progress, + ) + + return len(memberships_to_update_rows) + + async def _sliding_sync_membership_snapshots_fix_forgotten_column_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """ + Background update to update the `sliding_sync_membership_snapshots` -> + `forgotten` column to be in sync with the `room_memberships` table. + + Because of previously flawed code (now fixed); any room that someone has + forgotten and subsequently re-joined or had any new membership on, we need to go + and update the column to match the `room_memberships` table as it has fallen out + of sync. + """ + last_event_stream_ordering = progress.get( + "last_event_stream_ordering", -(1 << 31) + ) + + def _txn( + txn: LoggingTransaction, + ) -> int: + """ + Returns: + The number of rows updated. + """ + + # To simplify things, we can just recheck any row in + # `sliding_sync_membership_snapshots` with `forgotten=1` + txn.execute( + """ + SELECT + s.room_id, + s.user_id, + s.membership_event_id, + s.event_stream_ordering, + m.forgotten + FROM sliding_sync_membership_snapshots AS s + INNER JOIN room_memberships AS m ON (s.membership_event_id = m.event_id) + WHERE s.event_stream_ordering > ? + AND s.forgotten = 1 + ORDER BY s.event_stream_ordering ASC + LIMIT ? + """, + (last_event_stream_ordering, batch_size), + ) + + memberships_to_update_rows = cast( + List[Tuple[str, str, str, int, int]], + txn.fetchall(), + ) + if not memberships_to_update_rows: + return 0 + + # Assemble the values to update + # + # (room_id, user_id) + key_values: List[Tuple[str, str]] = [] + # (forgotten,) + value_values: List[Tuple[int]] = [] + for ( + room_id, + user_id, + _membership_event_id, + _event_stream_ordering, + forgotten, + ) in memberships_to_update_rows: + key_values.append( + ( + room_id, + user_id, + ) + ) + value_values.append((forgotten,)) + + # Update all of the rows in one go + self.db_pool.simple_update_many_txn( + txn, + table="sliding_sync_membership_snapshots", + key_names=("room_id", "user_id"), + key_values=key_values, + value_names=("forgotten",), + value_values=value_values, + ) + + # Update the progress + ( + _room_id, + _user_id, + _membership_event_id, + event_stream_ordering, + _forgotten, + ) = memberships_to_update_rows[-1] + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE, + { + "last_event_stream_ordering": event_stream_ordering, + }, + ) + + return len(memberships_to_update_rows) + + num_rows = await self.db_pool.runInteraction( + "_sliding_sync_membership_snapshots_fix_forgotten_column_bg_update", + _txn, + ) + + if not num_rows: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_FIX_FORGOTTEN_COLUMN_BG_UPDATE + ) + + return num_rows + + async def fixup_max_depth_cap_bg_update( + self, progress: JsonDict, batch_size: int + ) -> int: + """Fixes the topological ordering for events that have a depth greater + than MAX_DEPTH. This should fix /messages ordering oddities.""" + + room_id_bound = progress.get("room_id", "") + + def redo_max_depth_bg_update_txn(txn: LoggingTransaction) -> Tuple[bool, int]: + txn.execute( + """ + SELECT room_id, room_version FROM rooms + WHERE room_id > ? + ORDER BY room_id + LIMIT ? + """, + (room_id_bound, batch_size), + ) + + # Find the next room ID to process, with a relevant room version. + room_ids: List[str] = [] + max_room_id: Optional[str] = None + for room_id, room_version_str in txn: + max_room_id = room_id + + # We only want to process rooms with a known room version that + # has strict canonical json validation enabled. + room_version = KNOWN_ROOM_VERSIONS.get(room_version_str) + if room_version and room_version.strict_canonicaljson: + room_ids.append(room_id) + + if max_room_id is None: + # The query did not return any rooms, so we are done. + return True, 0 + + # Update the progress to the last room ID we pulled from the DB, + # this ensures we always make progress. + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.FIXUP_MAX_DEPTH_CAP, + progress={"room_id": max_room_id}, + ) + + if not room_ids: + # There were no rooms in this batch that required the fix. + return False, 0 + + clause, list_args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + sql = f""" + UPDATE events SET topological_ordering = ? + WHERE topological_ordering > ? AND {clause} + """ + args = [MAX_DEPTH, MAX_DEPTH] + args.extend(list_args) + txn.execute(sql, args) + + return False, len(room_ids) + + done, num_rooms = await self.db_pool.runInteraction( + "redo_max_depth_bg_update", redo_max_depth_bg_update_txn + ) + + if done: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.FIXUP_MAX_DEPTH_CAP + ) + + return num_rooms + + +def _resolve_stale_data_in_sliding_sync_tables( + txn: LoggingTransaction, +) -> None: + """ + Clears stale/out-of-date entries from the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables. + + This accounts for when someone downgrades their Synapse version and then upgrades it + again. This will ensure that we don't have any stale/out-of-date data in the + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` tables since any new + events sent in rooms would have also needed to be written to the sliding sync + tables. For example a new event needs to bump `event_stream_ordering` in + `sliding_sync_joined_rooms` table or some state in the room changing (like the room + name). Or another example of someone's membership changing in a room affecting + `sliding_sync_membership_snapshots`. + + This way, if a row exists in the sliding sync tables, we are able to rely on it + (accurate data). And if a row doesn't exist, we use a fallback to get the same info + until the background updates fill in the rows or a new event comes in triggering it + to be fully inserted. + + FIXME: This can be removed once we bump `SCHEMA_COMPAT_VERSION` and run the + foreground update for + `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked by + https://github.com/element-hq/synapse/issues/17623) + """ + + _resolve_stale_data_in_sliding_sync_joined_rooms_table(txn) + _resolve_stale_data_in_sliding_sync_membership_snapshots_table(txn) + + +def _resolve_stale_data_in_sliding_sync_joined_rooms_table( + txn: LoggingTransaction, +) -> None: + """ + Clears stale/out-of-date entries from the `sliding_sync_joined_rooms` table and + kicks-off the background update to catch-up with what we missed while Synapse was + downgraded. + + See `_resolve_stale_data_in_sliding_sync_tables()` description above for more + context. + """ + + # Find the point when we stopped writing to the `sliding_sync_joined_rooms` table + txn.execute( + """ + SELECT event_stream_ordering + FROM sliding_sync_joined_rooms + ORDER BY event_stream_ordering DESC + LIMIT 1 + """, + ) + + # If we have nothing written to the `sliding_sync_joined_rooms` table, there is + # nothing to clean up + row = cast(Optional[Tuple[int]], txn.fetchone()) + max_stream_ordering_sliding_sync_joined_rooms_table = None + depends_on = None + if row is not None: + (max_stream_ordering_sliding_sync_joined_rooms_table,) = row + + txn.execute( + """ + SELECT room_id + FROM events + WHERE stream_ordering > ? + GROUP BY room_id + ORDER BY MAX(stream_ordering) ASC + """, + (max_stream_ordering_sliding_sync_joined_rooms_table,), + ) + + room_rows = txn.fetchall() + # No new events have been written to the `events` table since the last time we wrote + # to the `sliding_sync_joined_rooms` table so there is nothing to clean up. This is + # the expected normal scenario for people who have not downgraded their Synapse + # version. + if not room_rows: + return + + # 1000 is an arbitrary batch size with no testing + for chunk in batch_iter(room_rows, 1000): + # Handle updating the `sliding_sync_joined_rooms` table + # + # Clear out the stale data + DatabasePool.simple_delete_many_batch_txn( + txn, + table="sliding_sync_joined_rooms", + keys=("room_id",), + values=chunk, + ) + + # Update the `sliding_sync_joined_rooms_to_recalculate` table with the rooms + # that went stale and now need to be recalculated. + DatabasePool.simple_upsert_many_txn_native_upsert( + txn, + table="sliding_sync_joined_rooms_to_recalculate", + key_names=("room_id",), + key_values=chunk, + value_names=(), + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + value_values=[() for x in range(len(chunk))], + ) + else: + # Avoid adding the background updates when there is no data to run them on (if + # the homeserver has no rooms). The portdb script refuses to run with pending + # background updates and since we potentially add them every time the server + # starts, we add this check for to allow the script to breath. + txn.execute("SELECT 1 FROM local_current_membership LIMIT 1") + row = txn.fetchone() + if row is None: + # There are no rooms, so don't schedule the bg update. + return + + # Re-run the `sliding_sync_joined_rooms_to_recalculate` prefill if there is + # nothing in the `sliding_sync_joined_rooms` table + DatabasePool.simple_upsert_txn_native_upsert( + txn, + table="background_updates", + keyvalues={ + "update_name": _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE + }, + values={}, + # Only insert the row if it doesn't already exist. If it already exists, + # we're already working on it + insertion_values={ + "progress_json": "{}", + }, + ) + depends_on = _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE + + # Now kick-off the background update to catch-up with what we missed while Synapse + # was downgraded. + # + # We may need to catch-up on everything if we have nothing written to the + # `sliding_sync_joined_rooms` table yet. This could happen if someone had zero rooms + # on their server (so the normal background update completes), downgrade Synapse + # versions, join and create some new rooms, and upgrade again. + DatabasePool.simple_upsert_txn_native_upsert( + txn, + table="background_updates", + keyvalues={ + "update_name": _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE + }, + values={}, + # Only insert the row if it doesn't already exist. If it already exists, we will + # eventually fill in the rows we're trying to populate. + insertion_values={ + # Empty progress is expected since it's not used for this background update. + "progress_json": "{}", + # Wait for the prefill to finish + "depends_on": depends_on, + }, + ) + + +def _resolve_stale_data_in_sliding_sync_membership_snapshots_table( + txn: LoggingTransaction, +) -> None: + """ + Clears stale/out-of-date entries from the `sliding_sync_membership_snapshots` table + and kicks-off the background update to catch-up with what we missed while Synapse + was downgraded. + + See `_resolve_stale_data_in_sliding_sync_tables()` description above for more + context. + """ + + # Find the point when we stopped writing to the `sliding_sync_membership_snapshots` table + txn.execute( + """ + SELECT event_stream_ordering + FROM sliding_sync_membership_snapshots + ORDER BY event_stream_ordering DESC + LIMIT 1 + """, + ) + + # If we have nothing written to the `sliding_sync_membership_snapshots` table, + # there is nothing to clean up + row = cast(Optional[Tuple[int]], txn.fetchone()) + max_stream_ordering_sliding_sync_membership_snapshots_table = None + if row is not None: + (max_stream_ordering_sliding_sync_membership_snapshots_table,) = row + + # XXX: Since `forgotten` is simply a flag on the `room_memberships` table that is + # set out-of-band, there is no way to tell whether it was set while Synapse was + # downgraded. The only thing the user can do is `/forget` again if they run into + # this. + # + # This only picks up changes to memberships. + txn.execute( + """ + SELECT user_id, room_id + FROM local_current_membership + WHERE event_stream_ordering > ? + ORDER BY event_stream_ordering ASC + """, + (max_stream_ordering_sliding_sync_membership_snapshots_table,), + ) + + membership_rows = txn.fetchall() + # No new events have been written to the `events` table since the last time we wrote + # to the `sliding_sync_membership_snapshots` table so there is nothing to clean up. + # This is the expected normal scenario for people who have not downgraded their + # Synapse version. + if not membership_rows: + return + + # 1000 is an arbitrary batch size with no testing + for chunk in batch_iter(membership_rows, 1000): + # Handle updating the `sliding_sync_membership_snapshots` table + # + DatabasePool.simple_delete_many_batch_txn( + txn, + table="sliding_sync_membership_snapshots", + keys=("user_id", "room_id"), + values=chunk, + ) + else: + # Avoid adding the background updates when there is no data to run them on (if + # the homeserver has no rooms). The portdb script refuses to run with pending + # background updates and since we potentially add them every time the server + # starts, we add this check for to allow the script to breath. + txn.execute("SELECT 1 FROM local_current_membership LIMIT 1") + row = txn.fetchone() + if row is None: + # There are no rooms, so don't schedule the bg update. + return + + # Now kick-off the background update to catch-up with what we missed while Synapse + # was downgraded. + # + # We may need to catch-up on everything if we have nothing written to the + # `sliding_sync_membership_snapshots` table yet. This could happen if someone had + # zero rooms on their server (so the normal background update completes), downgrade + # Synapse versions, join and create some new rooms, and upgrade again. + # + progress_json: JsonDict = {} + if max_stream_ordering_sliding_sync_membership_snapshots_table is not None: + progress_json["initial_phase"] = False + progress_json["last_event_stream_ordering"] = ( + max_stream_ordering_sliding_sync_membership_snapshots_table + ) + + DatabasePool.simple_upsert_txn_native_upsert( + txn, + table="background_updates", + keyvalues={ + "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE + }, + values={}, + # Only insert the row if it doesn't already exist. If it already exists, we will + # eventually fill in the rows we're trying to populate. + insertion_values={ + "progress_json": json_encoder.encode(progress_json), + }, + ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4d4877c4c3..3db4460f57 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -30,6 +30,7 @@ from typing import ( Dict, Iterable, List, + Literal, Mapping, MutableMapping, Optional, @@ -41,7 +42,6 @@ from typing import ( import attr from prometheus_client import Gauge -from typing_extensions import Literal from twisted.internet import defer @@ -61,7 +61,13 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) -from synapse.logging.opentracing import start_active_span, tag_args, trace +from synapse.logging.opentracing import ( + SynapseTags, + set_tag, + start_active_span, + tag_args, + trace, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -83,6 +89,7 @@ from synapse.storage.util.id_generators import ( from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id from synapse.types.state import StateFilter +from synapse.types.storage import _BackgroundUpdates from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -98,6 +105,26 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class DatabaseCorruptionError(RuntimeError): + """We found an event in the DB that has a persisted event ID that doesn't + match its computed event ID.""" + + def __init__( + self, room_id: str, persisted_event_id: str, computed_event_id: str + ) -> None: + self.room_id = room_id + self.persisted_event_id = persisted_event_id + self.computed_event_id = computed_event_id + + message = ( + f"Database corruption: Event {persisted_event_id} in room {room_id} " + f"from the database appears to have been modified (calculated " + f"event id {computed_event_id})" + ) + + super().__init__(message) + + # These values are used in the `enqueue_event` and `_fetch_loop` methods to # control how we batch/bulk fetch events from the database. # The values are plucked out of thing air to make initial sync run faster @@ -166,6 +193,14 @@ class _EventRow: outlier: bool +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Event metadata returned by `get_metadata_for_event(..)`""" + + sender: str + received_ts: int + + class EventRedactBehaviour(Enum): """ What to do when retrieving a redacted event from the database. @@ -304,6 +339,16 @@ class EventsWorkerStore(SQLBaseStore): writers=["master"], ) + # Added to accommodate some queries for the admin API in order to fetch/filter + # membership events by when it was received + self.db_pool.updates.register_background_index_update( + update_name="events_received_ts_index", + index_name="received_ts_idx", + table="events", + columns=("received_ts",), + where_clause="type = 'm.room.member'", + ) + def get_un_partial_stated_events_token(self, instance_name: str) -> int: return ( self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer( @@ -457,6 +502,8 @@ class EventsWorkerStore(SQLBaseStore): ) -> Optional[EventBase]: """Get an event from the database by event_id. + Events for unknown room versions will also be filtered out. + Args: event_id: The event_id of the event to fetch @@ -502,6 +549,7 @@ class EventsWorkerStore(SQLBaseStore): return event + @trace async def get_events( self, event_ids: Collection[str], @@ -511,6 +559,10 @@ class EventsWorkerStore(SQLBaseStore): ) -> Dict[str, EventBase]: """Get events from the database + Unknown events will be omitted from the response. + + Events for unknown room versions will also be filtered out. + Args: event_ids: The event_ids of the events to fetch @@ -529,6 +581,11 @@ class EventsWorkerStore(SQLBaseStore): Returns: A mapping from event_id to event. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + events = await self.get_events_as_list( event_ids, redact_behaviour=redact_behaviour, @@ -553,6 +610,8 @@ class EventsWorkerStore(SQLBaseStore): Unknown events will be omitted from the response. + Events for unknown room versions will also be filtered out. + Args: event_ids: The event_ids of the events to fetch @@ -574,6 +633,10 @@ class EventsWorkerStore(SQLBaseStore): Note that the returned list may be smaller than the list of event IDs if not all events could be fetched. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) if not event_ids: return [] @@ -694,10 +757,11 @@ class EventsWorkerStore(SQLBaseStore): return events + @trace @cancellable async def get_unredacted_events_from_cache_or_db( self, - event_ids: Iterable[str], + event_ids: Collection[str], allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. @@ -719,6 +783,11 @@ class EventsWorkerStore(SQLBaseStore): Returns: map from event id to result """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + # Shortcut: check if we have any events in the *in memory* cache - this function # may be called repeatedly for the same event so at this point we cannot reach # out to any external cache for performance reasons. The external cache is @@ -755,9 +824,9 @@ class EventsWorkerStore(SQLBaseStore): if missing_events_ids: - async def get_missing_events_from_cache_or_db() -> ( - Dict[str, EventCacheEntry] - ): + async def get_missing_events_from_cache_or_db() -> Dict[ + str, EventCacheEntry + ]: """Fetches the events in `missing_event_ids` from the database. Also creates entries in `self._current_event_fetches` to allow @@ -907,7 +976,7 @@ class EventsWorkerStore(SQLBaseStore): events, update_metrics=update_metrics ) - missing_event_ids = (e for e in events if e not in event_map) + missing_event_ids = [e for e in events if e not in event_map] event_map.update( await self._get_events_from_external_cache( events=missing_event_ids, @@ -917,8 +986,9 @@ class EventsWorkerStore(SQLBaseStore): return event_map + @trace async def _get_events_from_external_cache( - self, events: Iterable[str], update_metrics: bool = True + self, events: Collection[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: """Fetch events from any configured external cache. @@ -928,6 +998,10 @@ class EventsWorkerStore(SQLBaseStore): events: list of event_ids to fetch update_metrics: Whether to update the cache hit ratio metrics """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "events.length", + str(len(events)), + ) event_map = {} for event_id in events: @@ -1193,6 +1267,7 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire_errback, e) + @trace async def _get_events_from_db( self, event_ids: Collection[str] ) -> Dict[str, EventCacheEntry]: @@ -1211,6 +1286,11 @@ class EventsWorkerStore(SQLBaseStore): map from event id to result. May return extra events which weren't asked for. """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} @@ -1356,10 +1436,8 @@ class EventsWorkerStore(SQLBaseStore): if original_ev.event_id != event_id: # it's difficult to see what to do here. Pretty much all bets are off # if Synapse cannot rely on the consistency of its database. - raise RuntimeError( - f"Database corruption: Event {event_id} in room {d['room_id']} " - f"from the database appears to have been modified (calculated " - f"event id {original_ev.event_id})" + raise DatabaseCorruptionError( + d["room_id"], event_id, original_ev.event_id ) event_map[event_id] = original_ev @@ -1639,7 +1717,7 @@ class EventsWorkerStore(SQLBaseStore): txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) - found_events = {eid for eid, in txn} + found_events = {eid for (eid,) in txn} # ... and then we can update the results for each key return {eid: (eid in found_events) for eid in event_ids} @@ -1838,9 +1916,9 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, instance_name, limit)) - new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = ( - [] - ) + new_event_updates: List[ + Tuple[int, Tuple[str, str, str, str, str, str]] + ] = [] row: Tuple[int, str, str, str, str, str, str] # Type safety: iterating over `txn` yields `Tuple`, i.e. # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a @@ -2439,3 +2517,141 @@ class EventsWorkerStore(SQLBaseStore): ) self.invalidate_get_event_cache_after_txn(txn, event_id) + + async def get_events_sent_by_user_in_room( + self, user_id: str, room_id: str, limit: int, filter: Optional[List[str]] = None + ) -> Optional[List[str]]: + """ + Get a list of event ids of events sent by the user in the specified room + + Args: + user_id: user ID to search against + room_id: room ID of the room to search for events in + filter: type of events to filter for + limit: maximum number of event ids to return + """ + + def _get_events_by_user_in_room_txn( + txn: LoggingTransaction, + user_id: str, + room_id: str, + filter: Optional[List[str]], + batch_size: int, + offset: int, + ) -> Tuple[Optional[List[str]], int]: + if filter: + base_clause, args = make_in_list_sql_clause( + txn.database_engine, "type", filter + ) + clause = f"AND {base_clause}" + parameters = (user_id, room_id, *args, batch_size, offset) + else: + clause = "" + parameters = (user_id, room_id, batch_size, offset) + + sql = f""" + SELECT event_id FROM events + WHERE sender = ? AND room_id = ? + {clause} + ORDER BY received_ts DESC + LIMIT ? + OFFSET ? + """ + txn.execute(sql, parameters) + res = txn.fetchall() + if res: + events = [row[0] for row in res] + else: + events = None + + return events, offset + batch_size + + offset = 0 + batch_size = 100 + if batch_size > limit: + batch_size = limit + + selected_ids: List[str] = [] + while offset < limit: + res, offset = await self.db_pool.runInteraction( + "get_events_by_user", + _get_events_by_user_in_room_txn, + user_id, + room_id, + filter, + batch_size, + offset, + ) + if res: + selected_ids = selected_ids + res + else: + break + return selected_ids + + async def have_finished_sliding_sync_background_jobs(self) -> bool: + """Return if it's safe to use the sliding sync membership tables.""" + + return await self.db_pool.updates.have_completed_background_updates( + ( + _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE, + _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE, + _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, + ) + ) + + async def get_sent_invite_count_by_user(self, user_id: str, from_ts: int) -> int: + """ + Get the number of invites sent by the given user at or after the provided timestamp. + + Args: + user_id: user ID to search against + from_ts: a timestamp in milliseconds from the unix epoch. Filters against + `events.received_ts` + + """ + + def _get_sent_invite_count_by_user_txn( + txn: LoggingTransaction, user_id: str, from_ts: int + ) -> int: + sql = """ + SELECT COUNT(rm.event_id) + FROM room_memberships AS rm + INNER JOIN events AS e USING(event_id) + WHERE rm.sender = ? + AND rm.membership = 'invite' + AND e.type = 'm.room.member' + AND e.received_ts >= ? + """ + + txn.execute(sql, (user_id, from_ts)) + res = txn.fetchone() + + if res is None: + return 0 + return int(res[0]) + + return await self.db_pool.runInteraction( + "_get_sent_invite_count_by_user_txn", + _get_sent_invite_count_by_user_txn, + user_id, + from_ts, + ) + + @cached(tree=True) + async def get_metadata_for_event( + self, room_id: str, event_id: str + ) -> Optional[EventMetadata]: + row = await self.db_pool.simple_select_one( + table="events", + keyvalues={"room_id": room_id, "event_id": event_id}, + retcols=("sender", "received_ts"), + allow_none=True, + desc="get_metadata_for_event", + ) + if row is None: + return None + + return EventMetadata( + sender=row[0], + received_ts=row[1], + ) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 7617fd3ad4..04866524e3 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -19,6 +19,7 @@ # [This file includes modifications made by New Vector Limited] # # +import logging from enum import Enum from typing import ( TYPE_CHECKING, @@ -51,6 +52,8 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( "media_repository_drop_index_wo_method_2" ) +logger = logging.getLogger(__name__) + @attr.s(slots=True, frozen=True, auto_attribs=True) class LocalMedia: @@ -65,6 +68,7 @@ class LocalMedia: safe_from_quarantine: bool user_id: Optional[str] authenticated: Optional[bool] + sha256: Optional[str] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -79,6 +83,7 @@ class RemoteMedia: last_access_ts: int quarantined_by: Optional[str] authenticated: Optional[bool] + sha256: Optional[str] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -154,6 +159,26 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): unique=True, ) + self.db_pool.updates.register_background_index_update( + update_name="local_media_repository_sha256_idx", + index_name="local_media_repository_sha256", + table="local_media_repository", + where_clause="sha256 IS NOT NULL", + columns=[ + "sha256", + ], + ) + + self.db_pool.updates.register_background_index_update( + update_name="remote_media_cache_sha256_idx", + index_name="remote_media_cache_sha256", + table="remote_media_cache", + where_clause="sha256 IS NOT NULL", + columns=[ + "sha256", + ], + ) + self.db_pool.updates.register_background_update_handler( BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2, self._drop_media_index_without_method, @@ -221,6 +246,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "safe_from_quarantine", "user_id", "authenticated", + "sha256", ), allow_none=True, desc="get_local_media", @@ -239,6 +265,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): safe_from_quarantine=row[7], user_id=row[8], authenticated=row[9], + sha256=row[10], ) async def get_local_media_by_user_paginate( @@ -295,7 +322,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): quarantined_by, safe_from_quarantine, user_id, - authenticated + authenticated, + sha256 FROM local_media_repository WHERE user_id = ? ORDER BY {order_by_column} {order}, media_id ASC @@ -320,6 +348,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): safe_from_quarantine=bool(row[8]), user_id=row[9], authenticated=row[10], + sha256=row[11], ) for row in txn ] @@ -449,6 +478,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length: int, user_id: UserID, url_cache: Optional[str] = None, + sha256: Optional[str] = None, + quarantined_by: Optional[str] = None, ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -466,6 +497,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "user_id": user_id.to_string(), "url_cache": url_cache, "authenticated": authenticated, + "sha256": sha256, + "quarantined_by": quarantined_by, }, desc="store_local_media", ) @@ -477,20 +510,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name: Optional[str], media_length: int, user_id: UserID, + sha256: str, url_cache: Optional[str] = None, + quarantined_by: Optional[str] = None, ) -> None: + updatevalues = { + "media_type": media_type, + "upload_name": upload_name, + "media_length": media_length, + "url_cache": url_cache, + "sha256": sha256, + } + + # This should never be un-set by this function. + if quarantined_by is not None: + updatevalues["quarantined_by"] = quarantined_by + await self.db_pool.simple_update_one( "local_media_repository", keyvalues={ - "user_id": user_id.to_string(), "media_id": media_id, }, - updatevalues={ - "media_type": media_type, - "upload_name": upload_name, - "media_length": media_length, - "url_cache": url_cache, - }, + updatevalues=updatevalues, desc="update_local_media", ) @@ -657,6 +698,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "last_access_ts", "quarantined_by", "authenticated", + "sha256", ), allow_none=True, desc="get_cached_remote_media", @@ -674,6 +716,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): last_access_ts=row[5], quarantined_by=row[6], authenticated=row[7], + sha256=row[8], ) async def store_cached_remote_media( @@ -685,6 +728,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): time_now_ms: int, upload_name: Optional[str], filesystem_id: str, + sha256: Optional[str], ) -> None: if self.hs.config.media.enable_authenticated_media: authenticated = True @@ -703,6 +747,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "filesystem_id": filesystem_id, "last_access_ts": time_now_ms, "authenticated": authenticated, + "sha256": sha256, }, desc="store_cached_remote_media", ) @@ -729,10 +774,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute_batch( sql, - ( + [ (time_ms, media_origin, media_id) for media_origin, media_id in remote_media - ), + ], ) sql = ( @@ -740,7 +785,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_id = ?" ) - txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) + txn.execute_batch(sql, [(time_ms, media_id) for media_id in local_media]) await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn @@ -946,3 +991,46 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) + + async def get_is_hash_quarantined(self, sha256: str) -> bool: + """Get whether a specific sha256 hash digest matches any quarantined media. + + Returns: + None if the media_id doesn't exist. + """ + + # If we don't have the index yet, performance tanks, so we return False. + # In the background updates, remote_media_cache_sha256_idx is created + # after local_media_repository_sha256_idx, which is why we only need to + # check for the completion of the former. + if not await self.db_pool.updates.has_completed_background_update( + "remote_media_cache_sha256_idx" + ): + return False + + def get_matching_media_txn( + txn: LoggingTransaction, table: str, sha256: str + ) -> bool: + # Return on first match + sql = """ + SELECT 1 + FROM local_media_repository + WHERE sha256 = ? AND quarantined_by IS NOT NULL + + UNION ALL + + SELECT 1 + FROM remote_media_cache + WHERE sha256 = ? AND quarantined_by IS NOT NULL + LIMIT 1 + """ + txn.execute(sql, (sha256, sha256)) + row = txn.fetchone() + return row is not None + + return await self.db_pool.runInteraction( + "get_matching_media_txn", + get_matching_media_txn, + "local_media_repository", + sha256, + ) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 8e948c5e8d..c384675839 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -29,7 +29,6 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.registration import RegistrationWorkerStore from synapse.util.caches.descriptors import cached -from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: from synapse.server import HomeServer @@ -65,18 +64,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): self._mau_stats_only = hs.config.server.mau_stats_only - if self._update_on_this_worker: - # Do not add more reserved users than the total allowable number - self.db_pool.new_transaction( - db_conn, - "initialise_mau_threepids", - [], - [], - [], - self._initialise_reserved_users, - hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], - ) - @cached(num_args=0) async def get_monthly_active_count(self) -> int: """Generates current count of monthly active users @@ -174,26 +161,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): return await self.db_pool.runInteraction("list_users", _list_users) - async def get_registered_reserved_users(self) -> List[str]: - """Of the reserved threepids defined in config, retrieve those that are associated - with registered users - - Returns: - User IDs of actual users that are reserved - """ - users = [] - - for tp in self.hs.config.server.mau_limits_reserved_threepids[ - : self.hs.config.server.max_mau_value - ]: - user_id = await self.hs.get_datastores().main.get_user_id_by_threepid( - tp["medium"], canonicalise_email(tp["address"]) - ) - if user_id: - users.append(user_id) - - return users - @cached(num_args=1) async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]: """ @@ -289,50 +256,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) - reserved_users = await self.get_registered_reserved_users() await self.db_pool.runInteraction( - "reap_monthly_active_users", _reap_users, reserved_users + "reap_monthly_active_users", _reap_users, [] ) - def _initialise_reserved_users( - self, txn: LoggingTransaction, threepids: List[dict] - ) -> None: - """Ensures that reserved threepids are accounted for in the MAU table, should - be called on start up. - - Args: - txn: - threepids: List of threepid dicts to reserve - """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" - - # XXX what is this function trying to achieve? It upserts into - # monthly_active_users for each *registered* reserved mau user, but why? - # - # - shouldn't there already be an entry for each reserved user (at least - # if they have been active recently)? - # - # - if it's important that the timestamp is kept up to date, why do we only - # run this at startup? - - for tp in threepids: - user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"]) - - if user_id: - is_support = self.is_support_user_txn(txn, user_id) - if not is_support: - # We do this manually here to avoid hitting https://github.com/matrix-org/synapse/issues/6791 - self.db_pool.simple_upsert_txn( - txn, - table="monthly_active_users", - keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, - ) - else: - logger.warning("mau limit reserved threepid %s not found in db" % tp) - async def upsert_monthly_active_user(self, user_id: str) -> None: """Updates or inserts the user into the monthly active user table, which is used to track the current MAU usage of the server @@ -340,9 +267,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): Args: user_id: user to add/update """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" + assert self._update_on_this_worker, ( + "This worker is not designated to update MAUs" + ) # Support user never to be included in MAU stats. Note I can't easily call this # from upsert_monthly_active_user_txn because then I need a _txn form of @@ -379,9 +306,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): txn: user_id: user to add/update """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" + assert self._update_on_this_worker, ( + "This worker is not designated to update MAUs" + ) # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple @@ -409,9 +336,9 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): Args: user_id: the user_id to query """ - assert ( - self._update_on_this_worker - ), "This worker is not designated to update MAUs" + assert self._update_on_this_worker, ( + "This worker is not designated to update MAUs" + ) if self._limit_usage_by_mau or self._mau_stats_only: # Trial users and guests should not be included as part of MAU group diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 996aea808d..30d8a58d96 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -18,8 +18,13 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional +import json +from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from canonicaljson import encode_canonical_json + +from synapse.api.constants import ProfileFields +from synapse.api.errors import Codes, StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -27,13 +32,17 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict, JsonValue, UserID if TYPE_CHECKING: from synapse.server import HomeServer +# The number of bytes that the serialized profile can have. +MAX_PROFILE_SIZE = 65536 + + class ProfileWorkerStore(SQLBaseStore): def __init__( self, @@ -144,6 +153,16 @@ class ProfileWorkerStore(SQLBaseStore): return 50 async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: + """ + Fetch the display name and avatar URL of a user. + + Args: + user_id: The user ID to fetch the profile for. + + Returns: + The user's display name and avatar URL. Values may be null if unset + or if the user doesn't exist. + """ profile = await self.db_pool.simple_select_one( table="profiles", keyvalues={"full_user_id": user_id.to_string()}, @@ -158,6 +177,15 @@ class ProfileWorkerStore(SQLBaseStore): return ProfileInfo(avatar_url=profile[1], display_name=profile[0]) async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: + """ + Fetch the display name of a user. + + Args: + user_id: The user to get the display name for. + + Raises: + 404 if the user does not exist. + """ return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"full_user_id": user_id.to_string()}, @@ -166,6 +194,15 @@ class ProfileWorkerStore(SQLBaseStore): ) async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: + """ + Fetch the avatar URL of a user. + + Args: + user_id: The user to get the avatar URL for. + + Raises: + 404 if the user does not exist. + """ return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"full_user_id": user_id.to_string()}, @@ -173,7 +210,96 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue: + """ + Get a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The custom profile field name. + + Returns: + The string value if the field exists, otherwise raises 404. + """ + + def get_profile_field(txn: LoggingTransaction) -> JsonValue: + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + field_path = f'$."{field_name}"' + + if isinstance(self.database_engine, PostgresEngine): + sql = """ + SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_name, user_id.localpart), + ) + + # Test exists first since value being None is used for both + # missing and a null JSON value. + exists, value = cast(Tuple[bool, JsonValue], txn.fetchone()) + if not exists: + raise StoreError(404, "No row found") + return value + + else: + sql = """ + SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_path, user_id.localpart), + ) + + # If value_type is None, then the value did not exist. + value_type, value = cast( + Tuple[Optional[str], JsonValue], txn.fetchone() + ) + if not value_type: + raise StoreError(404, "No row found") + # If value_type is object or array, then need to deserialize the JSON. + # Scalar values are properly returned directly. + if value_type in ("object", "array"): + assert isinstance(value, str) + return json.loads(value) + return value + + return await self.db_pool.runInteraction("get_profile_field", get_profile_field) + + async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]: + """ + Get all custom profile fields for a user. + + Args: + user_id: The user's ID. + + Returns: + A dictionary of custom profile fields. + """ + result = await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcol="fields", + desc="get_profile_fields", + ) + # The SQLite driver doesn't automatically convert JSON to + # Python objects + if isinstance(self.database_engine, Sqlite3Engine) and result: + result = json.loads(result) + return result or {} + async def create_profile(self, user_id: UserID) -> None: + """ + Create a blank profile for a user. + + Args: + user_id: The user to create the profile for. + """ user_localpart = user_id.localpart await self.db_pool.simple_insert( table="profiles", @@ -181,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore): desc="create_profile", ) + def _check_profile_size( + self, + txn: LoggingTransaction, + user_id: UserID, + new_field_name: str, + new_value: JsonValue, + ) -> None: + # For each entry there are 4 quotes (2 each for key and value), 1 colon, + # and 1 comma. + PER_VALUE_EXTRA = 6 + + # Add the size of the current custom profile fields, ignoring the entry + # which will be overwritten. + if isinstance(txn.database_engine, PostgresEngine): + size_sql = """ + SELECT + OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + (new_field_name, user_id.localpart), + ) + else: + size_sql = """ + SELECT + LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + (f'$."{new_field_name}"', user_id.localpart), + ) + row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone()) + + # The values return null if the column is null. + total_bytes = ( + # Discount the opening and closing braces to avoid double counting, + # but add one for a comma. + # -2 + 1 = -1 + (row[0] - 1 if row[0] else 0) + + ( + row[1] + len("displayname") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.DISPLAYNAME and row[1] + else 0 + ) + + ( + row[2] + len("avatar_url") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.AVATAR_URL and row[2] + else 0 + ) + ) + + # Add the length of the field being added + the braces. + total_bytes += len(encode_canonical_json({new_field_name: new_value})) + + if total_bytes > MAX_PROFILE_SIZE: + raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE) + async def set_profile_displayname( self, user_id: UserID, new_displayname: Optional[str] ) -> None: @@ -193,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore): name is removed. """ user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={ - "displayname": new_displayname, - "full_user_id": user_id.to_string(), - }, - desc="set_profile_displayname", + + def set_profile_displayname(txn: LoggingTransaction) -> None: + if new_displayname is not None: + self._check_profile_size( + txn, user_id, ProfileFields.DISPLAYNAME, new_displayname + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "displayname": new_displayname, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_displayname", set_profile_displayname ) async def set_profile_avatar_url( @@ -215,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore): removed. """ user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()}, - desc="set_profile_avatar_url", + + def set_profile_avatar_url(txn: LoggingTransaction) -> None: + if new_avatar_url is not None: + self._check_profile_size( + txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "avatar_url": new_avatar_url, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_avatar_url", set_profile_avatar_url ) + async def set_profile_field( + self, user_id: UserID, field_name: str, new_value: JsonValue + ) -> None: + """ + Set a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + new_value: The value of the custom profile field. + """ + + # Encode to canonical JSON. + canonical_value = encode_canonical_json(new_value) + + def set_profile_field(txn: LoggingTransaction) -> None: + self._check_profile_size(txn, user_id, field_name, new_value) + + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import Json + + # Note that the || jsonb operator is not recursive, any duplicate + # keys will be taken from the second value. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb)) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields + """ + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + field_name, + # Pass as a JSON object since we have passing bytes disabled + # at the database driver. + Json(json.loads(canonical_value)), + ), + ) + else: + # You may be tempted to use json_patch instead of providing the parameters + # twice, but that recursively merges objects instead of replacing. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?))) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?)) + """ + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + json_field_name = f'$."{field_name}"' + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + json_field_name, + canonical_value, + json_field_name, + canonical_value, + ), + ) + + await self.db_pool.runInteraction("set_profile_field", set_profile_field) + + async def delete_profile_field(self, user_id: UserID, field_name: str) -> None: + """ + Remove a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + """ + + def delete_profile_field(txn: LoggingTransaction) -> None: + if isinstance(self.database_engine, PostgresEngine): + sql = """ + UPDATE profiles SET fields = fields - ? + WHERE user_id = ? + """ + txn.execute( + sql, + (field_name, user_id.localpart), + ) + else: + sql = """ + UPDATE profiles SET fields = json_remove(fields, ?) + WHERE user_id = ? + """ + txn.execute( + sql, + # This will error if field_name has double quotes in it. + (f'$."{field_name}"', user_id.localpart), + ) + + await self.db_pool.runInteraction("delete_profile_field", delete_profile_field) + class ProfileStore(ProfileWorkerStore): pass diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 3b81ed943c..a11f522f03 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -20,7 +20,7 @@ # import logging -from typing import Any, List, Set, Tuple, cast +from typing import Any, Set, Tuple, cast from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction @@ -199,9 +199,8 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # Update backward extremeties txn.execute_batch( - "INSERT INTO event_backward_extremities (room_id, event_id)" - " VALUES (?, ?)", - [(room_id, event_id) for event_id, in new_backwards_extrems], + "INSERT INTO event_backward_extremities (room_id, event_id) VALUES (?, ?)", + [(room_id, event_id) for (event_id,) in new_backwards_extrems], ) logger.info("[purge] finding state groups referenced by deleted events") @@ -215,7 +214,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """ ) - referenced_state_groups = {sg for sg, in txn} + referenced_state_groups = {sg for (sg,) in txn} logger.info( "[purge] found %i referenced state groups", len(referenced_state_groups) ) @@ -332,7 +331,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): return referenced_state_groups - async def purge_room(self, room_id: str) -> List[int]: + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room Args: @@ -348,7 +347,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # purge any of those rows which were added during the first. logger.info("[purge] Starting initial main purge of [1/2]") - state_groups_to_delete = await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "purge_room", self._purge_room_txn, room_id=room_id, @@ -356,18 +355,15 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) logger.info("[purge] Starting secondary main purge of [2/2]") - state_groups_to_delete.extend( - await self.db_pool.runInteraction( - "purge_room", - self._purge_room_txn, - room_id=room_id, - ), + await self.db_pool.runInteraction( + "purge_room", + self._purge_room_txn, + room_id=room_id, ) - logger.info("[purge] Done with main purge") - return state_groups_to_delete + logger.info("[purge] Done with main purge") - def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> None: # This collides with event persistence so we cannot write new events and metadata into # a room while deleting it or this transaction will fail. if isinstance(self.database_engine, PostgresEngine): @@ -376,18 +372,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): (room_id,), ) - # First, fetch all the state groups that should be deleted, before - # we delete that information. - txn.execute( - """ - SELECT DISTINCT state_group FROM events - INNER JOIN event_to_state_groups USING(event_id) - WHERE events.room_id = ? - """, - (room_id,), - ) - - state_groups = [row[0] for row in txn] + if isinstance(self.database_engine, PostgresEngine): + # Disable statement timeouts for this transaction; purging rooms can + # take a while! + txn.execute("SET LOCAL statement_timeout = 0") # Get all the auth chains that are referenced by events that are to be # deleted. @@ -454,6 +442,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # so must be deleted first. "local_current_membership", "room_memberships", + # Note: the sliding_sync_ tables have foreign keys to the `events` table + # so must be deleted first. + "sliding_sync_joined_rooms", + "sliding_sync_membership_snapshots", "events", "federation_inbound_events_staging", "receipts_graph", @@ -504,5 +496,3 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # periodically anyway (https://github.com/matrix-org/synapse/issues/5888) self._invalidate_caches_for_room_and_stream(txn, room_id) - - return state_groups diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index bbdde17711..86c87f78bf 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -109,6 +109,7 @@ def _load_rules( msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, + msc4210_enabled=experimental_config.msc4210_enabled, ) return filtered_rules diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3bde0ae0d4..9964331510 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -30,10 +30,12 @@ from typing import ( Mapping, Optional, Sequence, + Set, Tuple, cast, ) +import attr from immutabledict import immutabledict from synapse.api.constants import EduTypes @@ -43,6 +45,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -51,10 +54,12 @@ from synapse.types import ( JsonMapping, MultiWriterStreamToken, PersistedPosition, + StrCollection, ) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -62,6 +67,57 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(auto_attribs=True, slots=True, frozen=True) +class ReceiptInRoom: + receipt_type: str + user_id: str + event_id: str + thread_id: Optional[str] + data: JsonMapping + + @staticmethod + def merge_to_content(receipts: Collection["ReceiptInRoom"]) -> JsonMapping: + """Merge the given set of receipts (in a room) into the receipt + content format. + + Returns: + A mapping of the combined receipts: event ID -> receipt type -> user + ID -> receipt data. + """ + # MSC4102: always replace threaded receipts with unthreaded ones if + # there is a clash. This means we will drop some receipts, but MSC4102 + # is designed to drop semantically meaningless receipts, so this is + # okay. Previously, we would drop meaningful data! + # + # We do this by finding the unthreaded receipts, and then filtering out + # matching threaded receipts. + + # Set of (user_id, event_id) + unthreaded_receipts: Set[Tuple[str, str]] = { + (receipt.user_id, receipt.event_id) + for receipt in receipts + if receipt.thread_id is None + } + + # event_id -> receipt_type -> user_id -> receipt data + content: Dict[str, Dict[str, Dict[str, JsonMapping]]] = {} + for receipt in receipts: + data = receipt.data + if receipt.thread_id is not None: + if (receipt.user_id, receipt.event_id) in unthreaded_receipts: + # Ignore threaded receipts if we have an unthreaded one. + continue + + data = dict(data) + data["thread_id"] = receipt.thread_id + + content.setdefault(receipt.event_id, {}).setdefault( + receipt.receipt_type, {} + )[receipt.user_id] = data + + return content + + class ReceiptsWorkerStore(SQLBaseStore): def __init__( self, @@ -398,7 +454,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def f( txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + ) -> Mapping[str, Sequence[ReceiptInRoom]]: if from_key: sql = """ SELECT stream_id, instance_name, room_id, receipt_type, @@ -428,50 +484,46 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) - return [ - (room_id, receipt_type, user_id, event_id, thread_id, data) - for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn - if MultiWriterStreamToken.is_stream_position_in_range( + results: Dict[str, List[ReceiptInRoom]] = {} + for ( + stream_id, + instance_name, + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in txn: + if not MultiWriterStreamToken.is_stream_position_in_range( from_key, to_key, instance_name, stream_id + ): + continue + + results.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) ) - ] + + return results txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) - results: JsonDict = {} - for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results: - # We want a single event per room, since we want to batch the - # receipts by room, event and type. - room_event = results.setdefault( - room_id, - {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}}, - ) - - # The content is of the form: - # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } - event_entry = room_event["content"].setdefault(event_id, {}) - receipt_type_dict = event_entry.setdefault(receipt_type, {}) - - # MSC4102: always replace threaded receipts with unthreaded ones if there is a clash. - # Specifically: - # - if there is no existing receipt, great, set the data. - # - if there is an existing receipt, is it threaded (thread_id present)? - # YES: replace if this receipt has no thread id. NO: do not replace. - # This means we will drop some receipts, but MSC4102 is designed to drop semantically - # meaningless receipts, so this is okay. Previously, we would drop meaningful data! - receipt_data = db_to_json(data) - if user_id in receipt_type_dict: # existing receipt - # is the existing receipt threaded and we are currently processing an unthreaded one? - if "thread_id" in receipt_type_dict[user_id] and not thread_id: - receipt_type_dict[user_id] = ( - receipt_data # replace with unthreaded one - ) - else: # receipt does not exist, just set it - receipt_type_dict[user_id] = receipt_data - if thread_id: - receipt_type_dict[user_id]["thread_id"] = thread_id + results: JsonDict = { + room_id: { + "room_id": room_id, + "type": EduTypes.RECEIPT, + "content": ReceiptInRoom.merge_to_content(receipts), + } + for room_id, receipts in txn_results.items() + } results = { room_id: [results[room_id]] if room_id in results else [] @@ -479,6 +531,69 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results + async def get_linearized_receipts_for_events( + self, + room_and_event_ids: Collection[Tuple[str, str]], + ) -> Mapping[str, Sequence[ReceiptInRoom]]: + """Get all receipts for the given set of events. + + Arguments: + room_and_event_ids: A collection of 2-tuples of room ID and + event IDs to fetch receipts for + + Returns: + A list of receipts, one per room. + """ + if not room_and_event_ids: + return {} + + def get_linearized_receipts_for_events_txn( + txn: LoggingTransaction, + room_id_event_id_tuples: Collection[Tuple[str, str]], + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + clause, args = make_tuple_in_list_sql_clause( + self.database_engine, ("room_id", "event_id"), room_id_event_id_tuples + ) + + sql = f""" + SELECT room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized + WHERE {clause} + """ + + txn.execute(sql, args) + + return txn.fetchall() + + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} + for batch in batch_iter(room_and_event_ids, 1000): + batch_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_events", + get_linearized_receipts_for_events_txn, + batch, + ) + + for ( + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in batch_results: + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) + ) + + return room_to_receipts + @cached( num_args=2, ) @@ -550,6 +665,114 @@ class ReceiptsWorkerStore(SQLBaseStore): return results + async def get_linearized_receipts_for_user_in_rooms( + self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken + ) -> Mapping[str, Sequence[ReceiptInRoom]]: + """Fetch all receipts for the user in the given room. + + Returns: + A dict from room ID to receipts in the room. + """ + + def get_linearized_receipts_for_user_in_rooms_txn( + txn: LoggingTransaction, + batch_room_ids: StrCollection, + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized + WHERE {clause} AND user_id = ? AND stream_id <= ? + """ + + args.append(user_id) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=None, + high=to_key, + instance_name=instance_name, + pos=stream_id, + ) + ] + + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} + for batch in batch_iter(room_ids, 1000): + batch_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_events", + get_linearized_receipts_for_user_in_rooms_txn, + batch, + ) + + for ( + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in batch_results: + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) + ) + + return room_to_receipts + + async def get_rooms_with_receipts_between( + self, + room_ids: StrCollection, + from_key: MultiWriterStreamToken, + to_key: MultiWriterStreamToken, + ) -> StrCollection: + """Given a set of room_ids, find out which ones (may) have receipts + between the two tokens (> `from_token` and <= `to_token`).""" + + room_ids = self._receipts_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) + if not room_ids: + return [] + + def f(txn: LoggingTransaction, room_ids: StrCollection) -> StrCollection: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT DISTINCT room_id FROM receipts_linearized + WHERE {clause} AND ? < stream_id AND stream_id <= ? + """ + args.append(from_key.stream) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [room_id for (room_id,) in txn] + + results: List[str] = [] + for batch in batch_iter(room_ids, 1000): + batch_result = await self.db_pool.runInteraction( + "get_rooms_with_receipts_between", f, batch + ) + results.extend(batch_result) + + return results + async def get_users_sent_receipts_between( self, last_id: int, current_id: int ) -> List[str]: @@ -807,9 +1030,7 @@ class ReceiptsWorkerStore(SQLBaseStore): SELECT event_id WHERE room_id = ? AND stream_ordering IN ( SELECT max(stream_ordering) WHERE %s ) - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, [room_id] + list(args)) rows = txn.fetchall() @@ -954,6 +1175,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME, self._background_receipts_graph_unique_index, ) + self.db_pool.updates.register_background_index_update( + update_name="receipts_room_id_event_id_index", + index_name="receipts_linearized_event_id", + table="receipts_linearized", + columns=("room_id", "event_id"), + ) async def _populate_receipt_event_stream_ordering( self, progress: JsonDict, batch_size: int diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index df7f8a43b7..868803e169 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -32,7 +32,6 @@ from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, - ThreepidValidationError, ) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -149,30 +148,6 @@ class LoginTokenLookupResult: """The session ID advertised by the SSO Identity Provider.""" -@attr.s(frozen=True, slots=True, auto_attribs=True) -class ThreepidResult: - medium: str - address: str - validated_at: int - added_at: int - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class ThreepidValidationSession: - address: str - """address of the 3pid""" - medium: str - """medium of the 3pid""" - client_secret: str - """a secret provided by the client for this validation session""" - session_id: str - """ID of the validation session""" - last_send_attempt: int - """a number serving to dedupe send attempts for this session""" - validated_at: Optional[int] - """timestamp of when this session was validated if so""" - - class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -215,12 +190,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): self._set_expiration_date_when_missing, ) - # Create a background job for culling expired 3PID validity tokens - if hs.config.worker.run_background_tasks: - self._clock.looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS - ) - @cached() async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: """Returns info about the user account, if it exists.""" @@ -583,7 +552,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) - async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None: + async def set_user_type( + self, user: UserID, user_type: Optional[Union[UserTypes, str]] + ) -> None: """Sets the user type. Args: @@ -683,7 +654,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): retcol="user_type", allow_none=True, ) - return res is None + return res is None or res not in [UserTypes.BOT, UserTypes.SUPPORT] def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool: res = self.db_pool.simple_select_one_onecol_txn( @@ -759,17 +730,37 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): external_id: id on that system user_id: complete mxid that it is mapped to """ + self._invalidate_cache_and_stream( + txn, self.get_user_by_external_id, (auth_provider, external_id) + ) - self.db_pool.simple_insert_txn( + # This INSERT ... ON CONFLICT DO NOTHING statement will cause a + # 'could not serialize access due to concurrent update' + # if the row is added concurrently by another transaction. + # This is exactly what we want, as it makes the transaction get retried + # in a new snapshot where we can check for a genuine conflict. + was_inserted = self.db_pool.simple_upsert_txn( txn, table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, + keyvalues={"auth_provider": auth_provider, "external_id": external_id}, + values={}, + insertion_values={"user_id": user_id}, ) + if not was_inserted: + existing_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="user_external_ids", + keyvalues={"auth_provider": auth_provider, "user_id": user_id}, + retcol="external_id", + allow_none=True, + ) + + if existing_id != external_id: + raise ExternalIDReuseException( + f"{user_id!r} has external id {existing_id!r} for {auth_provider} but trying to add {external_id!r}" + ) + async def remove_user_external_id( self, auth_provider: str, external_id: str, user_id: str ) -> None: @@ -789,6 +780,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): }, desc="remove_user_external_id", ) + await self.invalidate_cache_and_stream( + "get_user_by_external_id", (auth_provider, external_id) + ) async def replace_user_external_id( self, @@ -809,29 +803,20 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ExternalIDReuseException if the new external_id could not be mapped. """ - def _remove_user_external_ids_txn( + def _replace_user_external_id_txn( txn: LoggingTransaction, - user_id: str, ) -> None: - """Remove all mappings from external user ids to a mxid - If these mappings are not found, this method does nothing. - - Args: - user_id: complete mxid that it is mapped to - """ - self.db_pool.simple_delete_txn( txn, table="user_external_ids", keyvalues={"user_id": user_id}, ) - def _replace_user_external_id_txn( - txn: LoggingTransaction, - ) -> None: - _remove_user_external_ids_txn(txn, user_id) - for auth_provider, external_id in record_external_ids: + self._invalidate_cache_and_stream( + txn, self.get_user_by_external_id, (auth_provider, external_id) + ) + self._record_user_external_id_txn( txn, auth_provider, @@ -847,6 +832,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): except self.database_engine.module.IntegrityError: raise ExternalIDReuseException() + @cached() async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -944,10 +930,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction("count_users", _count_users) async def count_real_users(self) -> int: - """Counts all users without a special user_type registered on the homeserver.""" + """Counts all users without the bot or support user_types registered on the homeserver.""" def _count_users(txn: LoggingTransaction) -> int: - txn.execute("SELECT COUNT(*) FROM users where user_type is null") + txn.execute( + f"SELECT COUNT(*) FROM users WHERE user_type IS NULL OR user_type NOT IN ('{UserTypes.BOT}', '{UserTypes.SUPPORT}')" + ) row = txn.fetchone() assert row is not None return row[0] @@ -965,161 +953,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return str(next_id) - async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: - """Returns user id from threepid - - Args: - medium: threepid medium e.g. email - address: threepid address e.g. me@example.com. This must already be - in canonical form. - - Returns: - The user ID or None if no user id/threepid mapping exists - """ - user_id = await self.db_pool.runInteraction( - "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address - ) - return user_id - - def get_user_id_by_threepid_txn( - self, txn: LoggingTransaction, medium: str, address: str - ) -> Optional[str]: - """Returns user id from threepid - - Args: - txn: - medium: threepid medium e.g. email - address: threepid address e.g. me@example.com - - Returns: - user id, or None if no user id/threepid mapping exists - """ - return self.db_pool.simple_select_one_onecol_txn( - txn, - "user_threepids", - {"medium": medium, "address": address}, - "user_id", - True, - ) - - async def user_add_threepid( - self, - user_id: str, - medium: str, - address: str, - validated_at: int, - added_at: int, - ) -> None: - await self.db_pool.simple_upsert( - "user_threepids", - {"medium": medium, "address": address}, - {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, - ) - - async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]: - results = cast( - List[Tuple[str, str, int, int]], - await self.db_pool.simple_select_list( - "user_threepids", - keyvalues={"user_id": user_id}, - retcols=["medium", "address", "validated_at", "added_at"], - desc="user_get_threepids", - ), - ) - return [ - ThreepidResult( - medium=r[0], - address=r[1], - validated_at=r[2], - added_at=r[3], - ) - for r in results - ] - - async def user_delete_threepid( - self, user_id: str, medium: str, address: str - ) -> None: - await self.db_pool.simple_delete( - "user_threepids", - keyvalues={"user_id": user_id, "medium": medium, "address": address}, - desc="user_delete_threepid", - ) - - async def add_user_bound_threepid( - self, user_id: str, medium: str, address: str, id_server: str - ) -> None: - """The server proxied a bind request to the given identity server on - behalf of the given user. We need to remember this in case the user - asks us to unbind the threepid. - - Args: - user_id - medium - address - id_server - """ - # We need to use an upsert, in case they user had already bound the - # threepid - await self.db_pool.simple_upsert( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - values={}, - insertion_values={}, - desc="add_user_bound_threepid", - ) - - async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]: - """Get the threepids that a user has bound to an identity server through the homeserver - The homeserver remembers where binds to an identity server occurred. Using this - method can retrieve those threepids. - - Args: - user_id: The ID of the user to retrieve threepids for - - Returns: - List of tuples of two strings: - medium: The medium of the threepid (e.g "email") - address: The address of the threepid (e.g "bob@example.com") - """ - return cast( - List[Tuple[str, str]], - await self.db_pool.simple_select_list( - table="user_threepid_id_server", - keyvalues={"user_id": user_id}, - retcols=["medium", "address"], - desc="user_get_bound_threepids", - ), - ) - - async def remove_user_bound_threepid( - self, user_id: str, medium: str, address: str, id_server: str - ) -> None: - """The server proxied an unbind request to the given identity server on - behalf of the given user, so we remove the mapping of threepid to - identity server. - - Args: - user_id - medium - address - id_server - """ - await self.db_pool.simple_delete( - table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - "id_server": id_server, - }, - desc="remove_user_bound_threepid", - ) - async def get_id_servers_user_bound( self, user_id: str, medium: str, address: str ) -> List[str]: @@ -1204,123 +1037,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return bool(res) - async def get_threepid_validation_session( - self, - medium: Optional[str], - client_secret: str, - address: Optional[str] = None, - sid: Optional[str] = None, - validated: Optional[bool] = True, - ) -> Optional[ThreepidValidationSession]: - """Gets a session_id and last_send_attempt (if available) for a - combination of validation metadata - - Args: - medium: The medium of the 3PID - client_secret: A unique string provided by the client to help identify this - validation attempt - address: The address of the 3PID - sid: The ID of the validation session - validated: Whether sessions should be filtered by - whether they have been validated already or not. None to - perform no filtering - - Returns: - A ThreepidValidationSession or None if a validation session is not found - """ - if not client_secret: - raise SynapseError( - 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM - ) - - keyvalues = {"client_secret": client_secret} - if medium: - keyvalues["medium"] = medium - if address: - keyvalues["address"] = address - if sid: - keyvalues["session_id"] = sid - - assert address or sid - - def get_threepid_validation_session_txn( - txn: LoggingTransaction, - ) -> Optional[ThreepidValidationSession]: - sql = """ - SELECT address, session_id, medium, client_secret, - last_send_attempt, validated_at - FROM threepid_validation_session WHERE %s - """ % ( - " AND ".join("%s = ?" % k for k in keyvalues.keys()), - ) - - if validated is not None: - sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") - - sql += " LIMIT 1" - - txn.execute(sql, list(keyvalues.values())) - row = txn.fetchone() - if not row: - return None - - return ThreepidValidationSession( - address=row[0], - session_id=row[1], - medium=row[2], - client_secret=row[3], - last_send_attempt=row[4], - validated_at=row[5], - ) - - return await self.db_pool.runInteraction( - "get_threepid_validation_session", get_threepid_validation_session_txn - ) - - async def delete_threepid_session(self, session_id: str) -> None: - """Removes a threepid validation session from the database. This can - be done after validation has been performed and whatever action was - waiting on it has been carried out - - Args: - session_id: The ID of the session to delete - """ - - def delete_threepid_session_txn(txn: LoggingTransaction) -> None: - self.db_pool.simple_delete_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id}, - ) - self.db_pool.simple_delete_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - ) - - await self.db_pool.runInteraction( - "delete_threepid_session", delete_threepid_session_txn - ) - - @wrap_as_background_process("cull_expired_threepid_validation_tokens") - async def cull_expired_threepid_validation_tokens(self) -> None: - """Remove threepid validation tokens with expiry dates that have passed""" - - def cull_expired_threepid_validation_tokens_txn( - txn: LoggingTransaction, ts: int - ) -> None: - sql = """ - DELETE FROM threepid_validation_token WHERE - expires < ? - """ - txn.execute(sql, (ts,)) - - await self.db_pool.runInteraction( - "cull_expired_threepid_validation_tokens", - cull_expired_threepid_validation_tokens_txn, - self._clock.time_msec(), - ) - @wrap_as_background_process("account_validity_set_expiration_dates") async def _set_expiration_date_when_missing(self) -> None: """ @@ -1512,15 +1228,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - pending, completed = cast( - Tuple[int, int], - self.db_pool.simple_select_one_txn( - txn, - "registration_tokens", - keyvalues={"token": token}, - retcols=["pending", "completed"], - ), + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=("pending", "completed"), ) + pending = int(row[0]) + completed = int(row[1]) # Decrement pending and increment completed self.db_pool.simple_update_one_txn( @@ -2093,6 +1808,136 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): func=is_user_approved_txn, ) + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. + """ + + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn( + self, txn: LoggingTransaction, user_id: str, deactivated: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_guest, (user_id,)) + + async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None: + """ + Set whether the user's account is suspended in the `users` table. + + Args: + user_id: The user ID of the user in question + suspended: True if the user is suspended, false if not + """ + await self.db_pool.runInteraction( + "set_user_suspended_status", + self.set_user_suspended_status_txn, + user_id, + suspended, + ) + + def set_user_suspended_status_txn( + self, txn: LoggingTransaction, user_id: str, suspended: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"suspended": suspended}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_suspended_status, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + async def set_user_locked_status(self, user_id: str, locked: bool) -> None: + """Set the `locked` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + locked: The value to set for `locked`. + """ + + await self.db_pool.runInteraction( + "set_user_locked_status", + self.set_user_locked_status_txn, + user_id, + locked, + ) + + def set_user_locked_status_txn( + self, txn: LoggingTransaction, user_id: str, locked: bool + ) -> None: + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"locked": locked}, + ) + self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + async def update_user_approval_status( + self, user_id: UserID, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean will be turned into an int (in update_user_approval_status_txn) + because the column is a smallint. + + Args: + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + await self.db_pool.runInteraction( + "update_user_approval_status", + self.update_user_approval_status_txn, + user_id.to_string(), + approved, + ) + + def update_user_approval_status_txn( + self, txn: LoggingTransaction, user_id: str, approved: bool + ) -> None: + """Set the user's 'approved' flag to the given value. + + The boolean is turned into an int because the column is a smallint. + + Args: + txn: the current database transaction. + user_id: the user to update the flag for. + approved: the value to set the flag to. + """ + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"approved": approved}, + ) + + # Invalidate the caches of methods that read the value of the 'approved' flag. + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -2205,117 +2050,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return nb_processed - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. - """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn( - self, txn: LoggingTransaction, user_id: str, deactivated: bool - ) -> None: - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - txn.call_after(self.is_guest.invalidate, (user_id,)) - - async def set_user_suspended_status(self, user_id: str, suspended: bool) -> None: - """ - Set whether the user's account is suspended in the `users` table. - - Args: - user_id: The user ID of the user in question - suspended: True if the user is suspended, false if not - """ - await self.db_pool.runInteraction( - "set_user_suspended_status", - self.set_user_suspended_status_txn, - user_id, - suspended, - ) - - def set_user_suspended_status_txn( - self, txn: LoggingTransaction, user_id: str, suspended: bool - ) -> None: - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"suspended": suspended}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_suspended_status, (user_id,) - ) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - async def set_user_locked_status(self, user_id: str, locked: bool) -> None: - """Set the `locked` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - locked: The value to set for `locked`. - """ - - await self.db_pool.runInteraction( - "set_user_locked_status", - self.set_user_locked_status_txn, - user_id, - locked, - ) - - def set_user_locked_status_txn( - self, txn: LoggingTransaction, user_id: str, locked: bool - ) -> None: - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"locked": locked}, - ) - self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,)) - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - - def update_user_approval_status_txn( - self, txn: LoggingTransaction, user_id: str, approved: bool - ) -> None: - """Set the user's 'approved' flag to the given value. - - The boolean is turned into an int because the column is a smallint. - - Args: - txn: the current database transaction. - user_id: the user to update the flag for. - approved: the value to set the flag to. - """ - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"approved": approved}, - ) - - # Invalidate the caches of methods that read the value of the 'approved' flag. - self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,)) - class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__( @@ -2326,9 +2060,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) - self._ignore_unknown_session_error = ( - hs.config.server.request_token_inhibit_3pid_errors - ) + self._ignore_unknown_session_error = False # Used to use whether 3pid errors were suppressed or not... Problem? self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") @@ -2514,7 +2246,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): the user, setting their displayname to the given value admin: is an admin user? user_type: type of user. One of the values from api.constants.UserTypes, - or None for a normal user. + a custom value set in the configuration file, or None for a normal + user. shadow_banned: Whether the user is shadow-banned, i.e. they may be told their requests succeeded but we ignore them. approved: Whether to consider the user has already been approved by an @@ -2796,96 +2529,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): desc="add_user_pending_deactivation", ) - async def validate_threepid_session( - self, session_id: str, client_secret: str, token: str, current_ts: int - ) -> Optional[str]: - """Attempt to validate a threepid session using a token - - Args: - session_id: The id of a validation session - client_secret: A unique string provided by the client to help identify - this validation attempt - token: A validation token - current_ts: The current unix time in milliseconds. Used for checking - token expiry status - - Raises: - ThreepidValidationError: if a matching validation token was not found or has - expired - - Returns: - A str representing a link to redirect the user to if there is one. - """ - - # Insert everything into a transaction in order to run atomically - def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]: - row = self.db_pool.simple_select_one_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - retcols=["client_secret", "validated_at"], - allow_none=True, - ) - - if not row: - if self._ignore_unknown_session_error: - # If we need to inhibit the error caused by an incorrect session ID, - # use None as placeholder values for the client secret and the - # validation timestamp. - # It shouldn't be an issue because they're both only checked after - # the token check, which should fail. And if it doesn't for some - # reason, the next check is on the client secret, which is NOT NULL, - # so we don't have to worry about the client secret matching by - # accident. - row = None, None - else: - raise ThreepidValidationError("Unknown session_id") - - retrieved_client_secret, validated_at = row - - row = self.db_pool.simple_select_one_txn( - txn, - table="threepid_validation_token", - keyvalues={"session_id": session_id, "token": token}, - retcols=["expires", "next_link"], - allow_none=True, - ) - - if not row: - raise ThreepidValidationError( - "Validation token not found or has expired" - ) - expires, next_link = row - - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - "This client_secret does not match the provided session_id" - ) - - # If the session is already validated, no need to revalidate - if validated_at: - return next_link - - if expires <= current_ts: - raise ThreepidValidationError( - "This token has expired. Please request a new one" - ) - - # Looks good. Validate the session - self.db_pool.simple_update_txn( - txn, - table="threepid_validation_session", - keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self._clock.time_msec()}, - ) - - return next_link - - # Return next_link if it exists - return await self.db_pool.runInteraction( - "validate_threepid_session_txn", validate_threepid_session_txn - ) - async def start_or_continue_validation_session( self, medium: str, @@ -2944,25 +2587,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def update_user_approval_status( - self, user_id: UserID, approved: bool - ) -> None: - """Set the user's 'approved' flag to the given value. - - The boolean will be turned into an int (in update_user_approval_status_txn) - because the column is a smallint. - - Args: - user_id: the user to update the flag for. - approved: the value to set the flag to. - """ - await self.db_pool.runInteraction( - "update_user_approval_status", - self.update_user_approval_status_txn, - user_id.to_string(), - approved, - ) - @wrap_as_background_process("delete_expired_login_tokens") async def _delete_expired_login_tokens(self) -> None: """Remove login tokens with expiry dates that have passed.""" diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 80a4bf95f2..347dbbba6b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -51,11 +51,15 @@ from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream -from synapse.storage._base import db_to_json, make_in_list_sql_clause +from synapse.storage._base import ( + db_to_json, + make_in_list_sql_clause, +) from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_tuple_in_list_sql_clause, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor @@ -73,6 +77,8 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitOverride: + # n.b. elsewhere in Synapse messages_per_second is represented as a float, but it is + # an integer in the database messages_per_second: int burst_count: int @@ -604,6 +610,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): search_term: Optional[str], public_rooms: Optional[bool], empty_rooms: Optional[bool], + emma_include_tombstone: bool = False, ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of rooms as json. @@ -623,6 +630,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): If true, empty rooms are queried. if false, empty rooms are excluded from the query. When it is none (the default), both empty rooms and none-empty rooms are queried. + emma_include_tombstone: If true, include tombstone events in the results. Returns: A list of room dicts and an integer representing the total number of rooms that exist given this query @@ -791,11 +799,43 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): room_count = cast(Tuple[int], txn.fetchone()) return rooms, room_count[0] - return await self.db_pool.runInteraction( + result = await self.db_pool.runInteraction( "get_rooms_paginate", _get_rooms_paginate_txn, ) + if emma_include_tombstone: + room_id_sql, room_id_args = make_in_list_sql_clause( + self.database_engine, "cse.room_id", [r["room_id"] for r in result[0]] + ) + + tombstone_sql = """ + SELECT cse.room_id, cse.event_id, ej.json + FROM current_state_events cse + JOIN event_json ej USING (event_id) + WHERE cse.type = 'm.room.tombstone' + AND {room_id_sql} + """.format( + room_id_sql=room_id_sql + ) + + def _get_tombstones_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: + txn.execute(tombstone_sql, room_id_args) + for room_id, event_id, json in txn: + for result_room in result[0]: + if result_room["room_id"] == room_id: + result_room["gay.rory.synapse_admin_extensions.tombstone"] = db_to_json(json) + break + return result[0], result[1] + + result = await self.db_pool.runInteraction( + "get_rooms_tombstones", _get_tombstones_txn, + ) + + return result + @cached(max_entries=10000) async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]: """Check if there are any overrides for ratelimiting for the given user @@ -1127,6 +1167,109 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return local_media_ids + def _quarantine_local_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media_ids: Set[str], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine local media items. + + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of media IDs for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + # Effectively a legacy path, update any media that was explicitly named. + if media_ids: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "media_id", media_ids + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + # Note that a rowcount of -1 can be used to indicate no rows were affected. + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + # Update any media that was identified via hash. + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + + if quarantined_by is not None: + sql += " AND safe_from_quarantine = FALSE" + + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + + def _quarantine_remote_media_txn( + self, + txn: LoggingTransaction, + hashes: Set[str], + media: Set[Tuple[str, str]], + quarantined_by: Optional[str], + ) -> int: + """Quarantine and unquarantine remote items + + Args: + txn (cursor) + hashes: A set of sha256 hashes for any media that should be quarantined + media_ids: A set of tuples (media_origin, media_id) for any media that should be quarantined + quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine + Returns: + The total number of media items quarantined + """ + total_media_quarantined = 0 + + if media: + sql_in_list_clause, sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + media, + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE {sql_in_list_clause}""" + + txn.execute(sql, [quarantined_by] + sql_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + total_media_quarantined = 0 + if hashes: + sql_many_clause_sql, sql_many_clause_args = make_in_list_sql_clause( + txn.database_engine, "sha256", hashes + ) + sql = f""" + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE {sql_many_clause_sql}""" + txn.execute(sql, [quarantined_by] + sql_many_clause_args) + total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 + + return total_media_quarantined + def _quarantine_media_txn( self, txn: LoggingTransaction, @@ -1146,40 +1289,93 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): Returns: The total number of media items quarantined """ - - # Update all the tables to set the quarantined_by flag - sql = """ - UPDATE local_media_repository - SET quarantined_by = ? - WHERE media_id = ? - """ - - # set quarantine - if quarantined_by is not None: - sql += "AND safe_from_quarantine = FALSE" - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hashes = set() + media_ids = set() + remote_media = set() + + # First, determine the hashes of the media we want to delete. + # We also want the media_ids for any media that lacks a hash. + if local_mxcs: + hash_sql_many_clause_sql, hash_sql_many_clause_args = ( + make_in_list_sql_clause(txn.database_engine, "media_id", local_mxcs) ) - # remove from quarantine - else: - txn.executemany( - sql, [(quarantined_by, media_id) for media_id in local_mxcs] + hash_sql = f"SELECT sha256, media_id FROM local_media_repository WHERE {hash_sql_many_clause_sql}" + if quarantined_by is not None: + hash_sql += " AND safe_from_quarantine = FALSE" + + txn.execute(hash_sql, hash_sql_many_clause_args) + for sha256, media_id in txn: + if sha256: + hashes.add(sha256) + else: + media_ids.add(media_id) + + # Do the same for remote media + if remote_mxcs: + hash_sql_in_list_clause, hash_sql_args = make_tuple_in_list_sql_clause( + txn.database_engine, + ("media_origin", "media_id"), + remote_mxcs, ) - # Note that a rowcount of -1 can be used to indicate no rows were affected. - total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 + hash_sql = f"SELECT sha256, media_origin, media_id FROM remote_media_cache WHERE {hash_sql_in_list_clause}" + txn.execute(hash_sql, hash_sql_args) + for sha256, media_origin, media_id in txn: + if sha256: + hashes.add(sha256) + else: + remote_media.add((media_origin, media_id)) - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), + count = self._quarantine_local_media_txn(txn, hashes, media_ids, quarantined_by) + count += self._quarantine_remote_media_txn( + txn, hashes, remote_media, quarantined_by ) - total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 - return total_media_quarantined + return count + + async def block_room(self, room_id: str, user_id: str) -> None: + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. + + Args: + room_id: Room to block + user_id: Who blocked it + """ + await self.db_pool.simple_upsert( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"user_id": user_id}, + desc="block_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. + + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False @@ -1382,6 +1578,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): partial_state_rooms = {row[0] for row in rows} return {room_id: room_id in partial_state_rooms for room_id in room_ids} + @cached(max_entries=10000, iterable=True) + async def get_partial_rooms(self) -> AbstractSet[str]: + """Get any "partial-state" rooms which the user is in. + + This is fast as the set of partially stated rooms at any point across + the whole server is small, and so such a query is fast. This is also + faster than looking up whether a set of room ID's are partially stated + via `is_partial_state_room_batched(...)` because of the sheer amount of + CPU time looking all the rooms up in the cache. + """ + + def _get_partial_rooms_for_user_txn( + txn: LoggingTransaction, + ) -> AbstractSet[str]: + sql = """ + SELECT room_id FROM partial_state_rooms + """ + txn.execute(sql) + return {room_id for (room_id,) in txn} + + return await self.db_pool.runInteraction( + "get_partial_rooms_for_user", _get_partial_rooms_for_user_txn + ) + async def get_join_event_id_and_device_lists_stream_id_for_partial_state( self, room_id: str ) -> Tuple[str, int]: @@ -1562,6 +1782,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): direction: Direction = Direction.BACKWARDS, user_id: Optional[str] = None, room_id: Optional[str] = None, + event_sender_user_id: Optional[str] = None, ) -> Tuple[List[Dict[str, Any]], int]: """Retrieve a paginated list of event reports @@ -1572,6 +1793,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): oldest first (forwards) user_id: search for user_id. Ignored if user_id is None room_id: search for room_id. Ignored if room_id is None + event_sender_user_id: search for the sender of the reported event. Ignored if + event_sender_user_id is None Returns: Tuple of: json list of event reports @@ -1591,6 +1814,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): filters.append("er.room_id LIKE ?") args.extend(["%" + room_id + "%"]) + if event_sender_user_id: + filters.append("events.sender = ?") + args.extend([event_sender_user_id]) + if direction == Direction.BACKWARDS: order = "DESC" else: @@ -1606,11 +1833,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): sql = """ SELECT COUNT(*) as total_event_reports FROM event_reports AS er + LEFT JOIN events USING(event_id) JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} - """.format( - where_clause - ) + """.format(where_clause) txn.execute(sql, args) count = cast(Tuple[int], txn.fetchone())[0] @@ -1626,8 +1852,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): room_stats_state.canonical_alias, room_stats_state.name FROM event_reports AS er - LEFT JOIN events - ON events.event_id = er.event_id + LEFT JOIN events USING(event_id) JOIN room_stats_state ON room_stats_state.room_id = er.room_id {where_clause} @@ -2343,6 +2568,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._invalidate_cache_and_stream( txn, self._get_partial_state_servers_at_join, (room_id,) ) + self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms) async def write_partial_state_rooms_join_event_id( self, @@ -2470,50 +2696,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) return next_id - async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. - - Can be called multiple times (though we'll only track the last user to - block this room). - - Can be called on a room unknown to this homeserver. - - Args: - room_id: Room to block - user_id: Who blocked it - """ - await self.db_pool.simple_upsert( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={"user_id": user_id}, - desc="block_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - - async def unblock_room(self, room_id: str) -> None: - """Remove the room from blocking list. - - Args: - room_id: Room to unblock - """ - await self.db_pool.simple_delete( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - desc="unblock_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - async def clear_partial_state_room(self, room_id: str) -> Optional[int]: """Clears the partial state flag for a room. @@ -2527,7 +2709,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): still contains events with partial state. """ try: - async with self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id: + async with ( + self._un_partial_stated_rooms_stream_id_gen.get_next() as un_partial_state_room_stream_id + ): await self.db_pool.runInteraction( "clear_partial_state_room", self._clear_partial_state_room_txn, @@ -2564,6 +2748,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._invalidate_cache_and_stream( txn, self._get_partial_state_servers_at_join, (room_id,) ) + self._invalidate_all_cache_and_stream(txn, self.get_partial_rooms) DatabasePool.simple_insert_txn( txn, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 1d9f0f52e1..7ca73abb83 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -19,6 +19,7 @@ # # import logging +from http import HTTPStatus from typing import ( TYPE_CHECKING, AbstractSet, @@ -39,6 +40,8 @@ from typing import ( import attr from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.logging.opentracing import trace from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -50,13 +53,20 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.stream import _filter_results_by_stream from synapse.storage.engines import Sqlite3Engine -from synapse.storage.roommember import MemberSummary, ProfileInfo, RoomsForUser +from synapse.storage.roommember import ( + MemberSummary, + ProfileInfo, + RoomsForUser, + RoomsForUserSlidingSync, +) from synapse.types import ( JsonDict, PersistedEventPosition, StateMap, StrCollection, + StreamToken, get_domain_from_id, ) from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -71,6 +81,7 @@ logger = logging.getLogger(__name__) _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +_POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000 @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -225,9 +236,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND m.room_id = c.room_id AND m.user_id = c.state_key WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (room_id, Membership.JOIN, *ids)) return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn} @@ -306,18 +315,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): # We do this all in one transaction to keep the cache small. # FIXME: get rid of this when we have room_stats - # Note, rejected events will have a null membership field, so - # we we manually filter them out. - sql = """ - SELECT count(*), membership FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? - AND membership IS NOT NULL - GROUP BY membership - """ + counts = self._get_member_counts_txn(txn, room_id) - txn.execute(sql, (room_id,)) res: Dict[str, MemberSummary] = {} - for count, membership in txn: + for membership, count in counts.items(): res.setdefault(membership, MemberSummary([], count)) # Order by membership (joins -> invites -> leave (former insiders) -> @@ -364,6 +365,31 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) @cached() + async def get_member_counts(self, room_id: str) -> Mapping[str, int]: + """Get a mapping of number of users by membership""" + + return await self.db_pool.runInteraction( + "get_member_counts", self._get_member_counts_txn, room_id + ) + + def _get_member_counts_txn( + self, txn: LoggingTransaction, room_id: str + ) -> Dict[str, int]: + """Get a mapping of number of users by membership""" + + # Note, rejected events will have a null membership field, so + # we we manually filter them out. + sql = """ + SELECT count(*), membership FROM current_state_events + WHERE type = 'm.room.member' AND room_id = ? + AND membership IS NOT NULL + GROUP BY membership + """ + + txn.execute(sql, (room_id,)) + return {membership: count for count, membership in txn} + + @cached() async def get_number_joined_users_in_room(self, room_id: str) -> int: return await self.db_pool.simple_select_one_onecol( table="current_state_events", @@ -524,9 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): WHERE user_id = ? AND %s - """ % ( - clause, - ) + """ % (clause,) txn.execute(sql, (user_id, *args)) results = [ @@ -631,10 +655,8 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ # Paranoia check. if not self.hs.is_mine_id(user_id): - raise Exception( - "Cannot call 'get_local_current_membership_for_user_in_room' on " - "non-local user %s" % (user_id,), - ) + message = f"Provided user_id {user_id} is a non-local user" + raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON) results = cast( Optional[Tuple[str, str]], @@ -692,6 +714,27 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {row[0] for row in txn} + async def get_rooms_user_currently_banned_from( + self, user_id: str + ) -> FrozenSet[str]: + """Returns a set of room_ids the user is currently banned from. + + If a remote user only returns rooms this server is currently + participating in. + """ + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.BAN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_user_currently_banned_from", + ) + + return frozenset(room_ids) + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. @@ -808,7 +851,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (user_id, *args)) - return {u: True for u, in txn} + return {u: True for (u,) in txn} to_return = {} for batch_user_ids in batch_iter(other_user_ids, 1000): @@ -828,6 +871,73 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): return {u for u, share_room in user_dict.items() if share_room} + @cached(max_entries=10000) + async def does_pair_of_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_id: str + ) -> bool: + raise NotImplementedError() + + @cachedList( + cached_method_name="does_pair_of_users_share_a_room_joined_or_invited", + list_name="other_user_ids", + ) + async def _do_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_ids: Collection[str] + ) -> Mapping[str, Optional[bool]]: + """Return mapping from user ID to whether they share a room with the + given user via being either joined or invited. + + Note: `None` and `False` are equivalent and mean they don't share a + room. + """ + + def do_users_share_a_room_joined_or_invited_txn( + txn: LoggingTransaction, user_ids: Collection[str] + ) -> Dict[str, bool]: + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids + ) + + # This query works by fetching both the list of rooms for the target + # user and the set of other users, and then checking if there is any + # overlap. + sql = f""" + SELECT DISTINCT b.state_key + FROM ( + SELECT room_id FROM current_state_events + WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND state_key = ? + ) AS a + INNER JOIN ( + SELECT room_id, state_key FROM current_state_events + WHERE type = 'm.room.member' AND (membership = 'join' OR membership = 'invite') AND {clause} + ) AS b using (room_id) + """ + + txn.execute(sql, (user_id, *args)) + return {u: True for (u,) in txn} + + to_return = {} + for batch_user_ids in batch_iter(other_user_ids, 1000): + res = await self.db_pool.runInteraction( + "do_users_share_a_room_joined_or_invited", + do_users_share_a_room_joined_or_invited_txn, + batch_user_ids, + ) + to_return.update(res) + + return to_return + + async def do_users_share_a_room_joined_or_invited( + self, user_id: str, other_user_ids: Collection[str] + ) -> Set[str]: + """Return the set of users who share a room with the first users via being either joined or invited""" + + user_dict = await self._do_users_share_a_room_joined_or_invited( + user_id, other_user_ids + ) + + return {u for u, share_room in user_dict.items() if share_room} + async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]: """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) @@ -1026,7 +1136,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): AND room_id = ? """ txn.execute(sql, (room_id,)) - return {d for d, in txn} + return {d for (d,) in txn} return await self.db_pool.runInteraction( "get_current_hosts_in_room", get_current_hosts_in_room_txn @@ -1094,7 +1204,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return tuple(d for d, in txn if d is not None) + return tuple(d for (d,) in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn @@ -1311,9 +1421,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): room_id = ? AND membership = ? AND NOT (%s) LIMIT 1 - """ % ( - clause, - ) + """ % (clause,) def _is_local_host_in_room_ignoring_users_txn( txn: LoggingTransaction, @@ -1337,11 +1445,23 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): keyvalues={"user_id": user_id, "room_id": room_id}, updatevalues={"forgotten": 1}, ) + # Handle updating the `sliding_sync_membership_snapshots` table + self.db_pool.simple_update_txn( + txn, + table="sliding_sync_membership_snapshots", + keyvalues={"user_id": user_id, "room_id": room_id}, + updatevalues={"forgotten": 1}, + ) self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id)) self._invalidate_cache_and_stream( txn, self.get_forgotten_rooms_for_user, (user_id,) ) + self._invalidate_cache_and_stream( + txn, + self.get_sliding_sync_rooms_for_user_from_membership_snapshots, + (user_id,), + ) await self.db_pool.runInteraction("forget_membership", f) @@ -1371,6 +1491,360 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): desc="room_forgetter_stream_pos", ) + @cached(iterable=True, max_entries=10000) + async def get_sliding_sync_rooms_for_user_from_membership_snapshots( + self, user_id: str + ) -> Mapping[str, RoomsForUserSlidingSync]: + """ + Get all the rooms for a user to handle a sliding sync request from the + `sliding_sync_membership_snapshots` table. These will be current memberships and + need to be rewound to the token range. + + Ignores forgotten rooms and rooms that the user has left themselves. + + Args: + user_id: The user ID to get the rooms for. + + Returns: + Map from room ID to membership info + """ + + def _txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + # XXX: If you use any new columns that can change (like from + # `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the + # `get_sliding_sync_rooms_for_user_from_membership_snapshots` cache in the + # appropriate places (and add tests). + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE user_id = ? + AND m.forgotten = 0 + AND (m.membership != 'leave' OR m.user_id != m.sender) + """ + txn.execute(sql, (user_id,)) + + return { + row[0]: RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=bool(row[9]), + ) + for row in txn + # We filter out unknown room versions proactively. They + # shouldn't go down sync and their metadata may be in a broken + # state (causing errors). + if row[4] in KNOWN_ROOM_VERSIONS + } + + return await self.db_pool.runInteraction( + "get_sliding_sync_rooms_for_user_from_membership_snapshots", + _txn, + ) + + async def get_sliding_sync_self_leave_rooms_after_to_token( + self, + user_id: str, + to_token: StreamToken, + ) -> Dict[str, RoomsForUserSlidingSync]: + """ + Get all the self-leave rooms for a user after the `to_token` (outside the token + range) that are potentially relevant[1] and needed to handle a sliding sync + request. The results are from the `sliding_sync_membership_snapshots` table and + will be current memberships and need to be rewound to the token range. + + [1] If a leave happens after the token range, we may have still been joined (or + any non-self-leave which is relevant to sync) to the room before so we need to + include it in the list of potentially relevant rooms and apply + our rewind logic (outside of this function) to see if it's actually relevant. + + This is basically a sister-function to + `get_sliding_sync_rooms_for_user_from_membership_snapshots`. We could + alternatively incorporate this logic into + `get_sliding_sync_rooms_for_user_from_membership_snapshots` but those results + are cached and the `to_token` isn't very cache friendly (people are constantly + requesting with new tokens) so we separate it out here. + + Args: + user_id: The user ID to get the rooms for. + to_token: Any self-leave memberships after this position will be returned. + + Returns: + Map from room ID to membership info + """ + # TODO: Potential to check + # `self._membership_stream_cache.has_entity_changed(...)` as an early-return + # shortcut. + + def _txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + m.room_type, + m.is_encrypted + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + WHERE user_id = ? + AND m.forgotten = 0 + AND m.membership = 'leave' + AND m.user_id = m.sender + AND (m.event_stream_ordering > ?) + """ + # If a leave happens after the token range, we may have still been joined + # (or any non-self-leave which is relevant to sync) to the room before so we + # need to include it in the list of potentially relevant rooms and apply our + # rewind logic (outside of this function). + # + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_to_token_position = to_token.room_key.stream + txn.execute(sql, (user_id, min_to_token_position)) + + # Map from room_id to membership info + room_membership_for_user_map: Dict[str, RoomsForUserSlidingSync] = {} + for row in txn: + room_for_user = RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=bool(row[9]), + ) + + # We filter out unknown room versions proactively. They shouldn't go + # down sync and their metadata may be in a broken state (causing + # errors). + if row[4] not in KNOWN_ROOM_VERSIONS: + continue + + # We only want to include the self-leave membership if it happened after + # the token range. + # + # Since the database pulls out more than necessary, we need to filter it + # down here. + if _filter_results_by_stream( + lower_token=None, + upper_token=to_token.room_key, + instance_name=room_for_user.event_pos.instance_name, + stream_ordering=room_for_user.event_pos.stream, + ): + continue + + room_membership_for_user_map[room_for_user.room_id] = room_for_user + + return room_membership_for_user_map + + return await self.db_pool.runInteraction( + "get_sliding_sync_self_leave_rooms_after_to_token", + _txn, + ) + + async def get_sliding_sync_room_for_user( + self, user_id: str, room_id: str + ) -> Optional[RoomsForUserSlidingSync]: + """Get the sliding sync room entry for the given user and room.""" + + def get_sliding_sync_room_for_user_txn( + txn: LoggingTransaction, + ) -> Optional[RoomsForUserSlidingSync]: + sql = """ + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE user_id = ? + AND m.forgotten = 0 + AND m.room_id = ? + """ + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + if not row: + return None + + return RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=row[9], + ) + + return await self.db_pool.runInteraction( + "get_sliding_sync_room_for_user", get_sliding_sync_room_for_user_txn + ) + + async def get_sliding_sync_room_for_user_batch( + self, user_id: str, room_ids: StrCollection + ) -> Dict[str, RoomsForUserSlidingSync]: + """Get the sliding sync room entry for the given user and rooms.""" + + if not room_ids: + return {} + + def get_sliding_sync_room_for_user_batch_txn( + txn: LoggingTransaction, + ) -> Dict[str, RoomsForUserSlidingSync]: + clause, args = make_in_list_sql_clause( + self.database_engine, "m.room_id", room_ids + ) + sql = f""" + SELECT m.room_id, m.sender, m.membership, m.membership_event_id, + r.room_version, + m.event_instance_name, m.event_stream_ordering, + m.has_known_state, + COALESCE(j.room_type, m.room_type), + COALESCE(j.is_encrypted, m.is_encrypted) + FROM sliding_sync_membership_snapshots AS m + INNER JOIN rooms AS r USING (room_id) + LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join') + WHERE m.forgotten = 0 + AND {clause} + AND user_id = ? + """ + args.append(user_id) + txn.execute(sql, args) + + return { + row[0]: RoomsForUserSlidingSync( + room_id=row[0], + sender=row[1], + membership=row[2], + event_id=row[3], + room_version_id=row[4], + event_pos=PersistedEventPosition(row[5], row[6]), + has_known_state=bool(row[7]), + room_type=row[8], + is_encrypted=row[9], + ) + for row in txn + } + + return await self.db_pool.runInteraction( + "get_sliding_sync_room_for_user_batch", + get_sliding_sync_room_for_user_batch_txn, + ) + + async def get_rooms_for_user_by_date( + self, user_id: str, from_ts: int + ) -> FrozenSet[str]: + """ + Fetch a list of rooms that the user has joined at or after the given timestamp, including + those they subsequently have left/been banned from. + + Args: + user_id: user ID of the user to search for + from_ts: a timestamp in ms from the unix epoch at which to begin the search at + """ + + def _get_rooms_for_user_by_join_date_txn( + txn: LoggingTransaction, user_id: str, timestamp: int + ) -> frozenset: + sql = """ + SELECT rm.room_id + FROM room_memberships AS rm + INNER JOIN events AS e USING (event_id) + WHERE rm.user_id = ? + AND rm.membership = 'join' + AND e.type = 'm.room.member' + AND e.received_ts >= ? + """ + txn.execute(sql, (user_id, timestamp)) + return frozenset([r[0] for r in txn]) + + return await self.db_pool.runInteraction( + "_get_rooms_for_user_by_join_date_txn", + _get_rooms_for_user_by_join_date_txn, + user_id, + from_ts, + ) + + async def set_room_participation(self, user_id: str, room_id: str) -> None: + """ + Record the provided user as participating in the given room + + Args: + user_id: the user ID of the user + room_id: ID of the room to set the participant in + """ + + def _set_room_participation_txn( + txn: LoggingTransaction, user_id: str, room_id: str + ) -> None: + sql = """ + UPDATE room_memberships + SET participant = true + WHERE event_id IN ( + SELECT event_id FROM local_current_membership + WHERE user_id = ? AND room_id = ? + ) + AND NOT participant + """ + txn.execute(sql, (user_id, room_id)) + + await self.db_pool.runInteraction( + "_set_room_participation_txn", _set_room_participation_txn, user_id, room_id + ) + + async def get_room_participation(self, user_id: str, room_id: str) -> bool: + """ + Check whether a user is listed as a participant in a room + + Args: + user_id: user ID of the user + room_id: ID of the room to check in + """ + + def _get_room_participation_txn( + txn: LoggingTransaction, user_id: str, room_id: str + ) -> bool: + sql = """ + SELECT participant + FROM local_current_membership AS l + INNER JOIN room_memberships AS r USING (event_id) + WHERE l.user_id = ? + AND l.room_id = ? + """ + txn.execute(sql, (user_id, room_id)) + res = txn.fetchone() + if res: + return res[0] + return False + + return await self.db_pool.runInteraction( + "_get_room_participation_txn", _get_room_participation_txn, user_id, room_id + ) + class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -1405,10 +1879,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): self, progress: JsonDict, batch_size: int ) -> int: target_min_stream_id = progress.get( - "target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined] + "target_min_stream_id_inclusive", + self._min_stream_order_on_start, # type: ignore[attr-defined] ) max_stream_id = progress.get( - "max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined] + "max_stream_id_exclusive", + self._stream_order_on_start + 1, # type: ignore[attr-defined] ) def add_membership_profile_txn(txn: LoggingTransaction) -> int: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 20fcfd3122..1d5c5e72ff 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -94,7 +94,7 @@ class SearchWorkerStore(SQLBaseStore): VALUES (?,?,?,to_tsvector('english', ?),?,?) """ - args1 = ( + args1 = [ ( entry.event_id, entry.room_id, @@ -104,7 +104,7 @@ class SearchWorkerStore(SQLBaseStore): entry.origin_server_ts, ) for entry in entries - ) + ] txn.execute_batch(sql, args1) @@ -177,9 +177,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): AND (%s) ORDER BY stream_ordering DESC LIMIT ? - """ % ( - " OR ".join("type = '%s'" % (t,) for t in TYPES), - ) + """ % (" OR ".join("type = '%s'" % (t,) for t in TYPES),) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py new file mode 100644
index 0000000000..6a62b11d1e --- /dev/null +++ b/synapse/storage/databases/main/sliding_sync.py
@@ -0,0 +1,603 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023, 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + + +import logging +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast + +import attr + +from synapse.api.errors import SlidingSyncUnknownPosition +from synapse.logging.opentracing import log_kv +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import MultiWriterStreamToken, RoomStreamToken +from synapse.types.handlers.sliding_sync import ( + HaveSentRoom, + HaveSentRoomFlag, + MutablePerConnectionState, + PerConnectionState, + RoomStatusMap, + RoomSyncConfig, +) +from synapse.util import json_encoder +from synapse.util.caches.descriptors import cached + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + +logger = logging.getLogger(__name__) + + +class SlidingSyncStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + update_name="sliding_sync_connection_room_configs_required_state_id_idx", + index_name="sliding_sync_connection_room_configs_required_state_id_idx", + table="sliding_sync_connection_room_configs", + columns=("required_state_id",), + ) + + self.db_pool.updates.register_background_index_update( + update_name="sliding_sync_membership_snapshots_membership_event_id_idx", + index_name="sliding_sync_membership_snapshots_membership_event_id_idx", + table="sliding_sync_membership_snapshots", + columns=("membership_event_id",), + ) + + self.db_pool.updates.register_background_index_update( + update_name="sliding_sync_membership_snapshots_user_id_stream_ordering", + index_name="sliding_sync_membership_snapshots_user_id_stream_ordering", + table="sliding_sync_membership_snapshots", + columns=("user_id", "event_stream_ordering"), + replaces_index="sliding_sync_membership_snapshots_user_id", + ) + + async def get_latest_bump_stamp_for_room( + self, + room_id: str, + ) -> Optional[int]: + """ + Get the `bump_stamp` for the room. + + The `bump_stamp` is the `stream_ordering` of the last event according to the + `bump_event_types`. This helps clients sort more readily without them needing to + pull in a bunch of the timeline to determine the last activity. + `bump_event_types` is a thing because for example, we don't want display name + changes to mark the room as unread and bump it to the top. For encrypted rooms, + we just have to consider any activity as a bump because we can't see the content + and the client has to figure it out for themselves. + + This should only be called where the server is participating + in the room (someone local is joined). + + Returns: + The `bump_stamp` for the room (which can be `None`). + """ + + return cast( + Optional[int], + await self.db_pool.simple_select_one_onecol( + table="sliding_sync_joined_rooms", + keyvalues={"room_id": room_id}, + retcol="bump_stamp", + # FIXME: This should be `False` once we bump `SCHEMA_COMPAT_VERSION` and run the + # foreground update for + # `sliding_sync_joined_rooms`/`sliding_sync_membership_snapshots` (tracked + # by https://github.com/element-hq/synapse/issues/17623) + # + # The should be `allow_none=False` in the future because event though + # `bump_stamp` itself can be `None`, we should have a row in the + # `sliding_sync_joined_rooms` table for any joined room. + allow_none=True, + ), + ) + + async def persist_per_connection_state( + self, + user_id: str, + device_id: str, + conn_id: str, + previous_connection_position: Optional[int], + per_connection_state: "MutablePerConnectionState", + ) -> int: + """Persist updates to the per-connection state for a sliding sync + connection. + + Returns: + The connection position of the newly persisted state. + """ + + # This cast is safe because the downstream code only cares about + # `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed + # alongside `SlidingSyncStore` wherever we create a store. + store = cast("DataStore", self) + + return await self.db_pool.runInteraction( + "persist_per_connection_state", + self.persist_per_connection_state_txn, + user_id=user_id, + device_id=device_id, + conn_id=conn_id, + previous_connection_position=previous_connection_position, + per_connection_state=await PerConnectionStateDB.from_state( + per_connection_state, store + ), + ) + + def persist_per_connection_state_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + conn_id: str, + previous_connection_position: Optional[int], + per_connection_state: "PerConnectionStateDB", + ) -> int: + # First we fetch (or create) the connection key associated with the + # previous connection position. + if previous_connection_position is not None: + # The `previous_connection_position` is a user-supplied value, so we + # need to make sure that the one they supplied is actually theirs. + sql = """ + SELECT connection_key + FROM sliding_sync_connection_positions + INNER JOIN sliding_sync_connections USING (connection_key) + WHERE + connection_position = ? + AND user_id = ? AND effective_device_id = ? AND conn_id = ? + """ + txn.execute( + sql, (previous_connection_position, user_id, device_id, conn_id) + ) + row = txn.fetchone() + if row is None: + raise SlidingSyncUnknownPosition() + + (connection_key,) = row + else: + # We're restarting the connection, so we clear the previous existing data we + # used to track it. We do this here to ensure that if we get lots of + # one-shot requests we don't stack up lots of entries. We have `ON DELETE + # CASCADE` setup on the dependent tables so this will clear out all the + # associated data. + self.db_pool.simple_delete_txn( + txn, + table="sliding_sync_connections", + keyvalues={ + "user_id": user_id, + "effective_device_id": device_id, + "conn_id": conn_id, + }, + ) + + (connection_key,) = self.db_pool.simple_insert_returning_txn( + txn, + table="sliding_sync_connections", + values={ + "user_id": user_id, + "effective_device_id": device_id, + "conn_id": conn_id, + "created_ts": self._clock.time_msec(), + }, + returning=("connection_key",), + ) + + # Define a new connection position for the updates + (connection_position,) = self.db_pool.simple_insert_returning_txn( + txn, + table="sliding_sync_connection_positions", + values={ + "connection_key": connection_key, + "created_ts": self._clock.time_msec(), + }, + returning=("connection_position",), + ) + + # We need to deduplicate the `required_state` JSON. We do this by + # fetching all JSON associated with the connection and comparing that + # with the updates to `required_state` + + # Dict from required state json -> required state ID + required_state_to_id: Dict[str, int] = {} + if previous_connection_position is not None: + rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_required_state", + keyvalues={"connection_key": connection_key}, + retcols=("required_state_id", "required_state"), + ) + for required_state_id, required_state in rows: + required_state_to_id[required_state] = required_state_id + + room_to_state_ids: Dict[str, int] = {} + unique_required_state: Dict[str, List[str]] = {} + for room_id, room_state in per_connection_state.room_configs.items(): + serialized_state = json_encoder.encode( + # We store the required state as a sorted list of event type / + # state key tuples. + sorted( + (event_type, state_key) + for event_type, state_keys in room_state.required_state_map.items() + for state_key in state_keys + ) + ) + + existing_state_id = required_state_to_id.get(serialized_state) + if existing_state_id is not None: + room_to_state_ids[room_id] = existing_state_id + else: + unique_required_state.setdefault(serialized_state, []).append(room_id) + + # Insert any new `required_state` json we haven't previously seen. + for serialized_required_state, room_ids in unique_required_state.items(): + (required_state_id,) = self.db_pool.simple_insert_returning_txn( + txn, + table="sliding_sync_connection_required_state", + values={ + "connection_key": connection_key, + "required_state": serialized_required_state, + }, + returning=("required_state_id",), + ) + for room_id in room_ids: + room_to_state_ids[room_id] = required_state_id + + # Copy over state from the previous connection position (we'll overwrite + # these rows with any changes). + if previous_connection_position is not None: + sql = """ + INSERT INTO sliding_sync_connection_streams + (connection_position, stream, room_id, room_status, last_token) + SELECT ?, stream, room_id, room_status, last_token + FROM sliding_sync_connection_streams + WHERE connection_position = ? + """ + txn.execute(sql, (connection_position, previous_connection_position)) + + sql = """ + INSERT INTO sliding_sync_connection_room_configs + (connection_position, room_id, timeline_limit, required_state_id) + SELECT ?, room_id, timeline_limit, required_state_id + FROM sliding_sync_connection_room_configs + WHERE connection_position = ? + """ + txn.execute(sql, (connection_position, previous_connection_position)) + + # We now upsert the changes to the various streams. + key_values = [] + value_values = [] + for room_id, have_sent_room in per_connection_state.rooms._statuses.items(): + key_values.append((connection_position, "rooms", room_id)) + value_values.append( + (have_sent_room.status.value, have_sent_room.last_token) + ) + + for room_id, have_sent_room in per_connection_state.receipts._statuses.items(): + key_values.append((connection_position, "receipts", room_id)) + value_values.append( + (have_sent_room.status.value, have_sent_room.last_token) + ) + + for ( + room_id, + have_sent_room, + ) in per_connection_state.account_data._statuses.items(): + key_values.append((connection_position, "account_data", room_id)) + value_values.append( + (have_sent_room.status.value, have_sent_room.last_token) + ) + + self.db_pool.simple_upsert_many_txn( + txn, + table="sliding_sync_connection_streams", + key_names=( + "connection_position", + "stream", + "room_id", + ), + key_values=key_values, + value_names=( + "room_status", + "last_token", + ), + value_values=value_values, + ) + + # ... and upsert changes to the room configs. + keys = [] + values = [] + for room_id, room_config in per_connection_state.room_configs.items(): + keys.append((connection_position, room_id)) + values.append((room_config.timeline_limit, room_to_state_ids[room_id])) + + self.db_pool.simple_upsert_many_txn( + txn, + table="sliding_sync_connection_room_configs", + key_names=( + "connection_position", + "room_id", + ), + key_values=keys, + value_names=( + "timeline_limit", + "required_state_id", + ), + value_values=values, + ) + + return connection_position + + @cached(iterable=True, max_entries=100000) + async def get_and_clear_connection_positions( + self, user_id: str, device_id: str, conn_id: str, connection_position: int + ) -> "PerConnectionState": + """Get the per-connection state for the given connection position.""" + + per_connection_state_db = await self.db_pool.runInteraction( + "get_and_clear_connection_positions", + self._get_and_clear_connection_positions_txn, + user_id=user_id, + device_id=device_id, + conn_id=conn_id, + connection_position=connection_position, + ) + + # This cast is safe because the downstream code only cares about + # `store.get_id_for_instance(...)` and `StreamWorkerStore` is mixed + # alongside `SlidingSyncStore` wherever we create a store. + store = cast("DataStore", self) + + return await per_connection_state_db.to_state(store) + + def _get_and_clear_connection_positions_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + conn_id: str, + connection_position: int, + ) -> "PerConnectionStateDB": + # The `previous_connection_position` is a user-supplied value, so we + # need to make sure that the one they supplied is actually theirs. + sql = """ + SELECT connection_key + FROM sliding_sync_connection_positions + INNER JOIN sliding_sync_connections USING (connection_key) + WHERE + connection_position = ? + AND user_id = ? AND effective_device_id = ? AND conn_id = ? + """ + txn.execute(sql, (connection_position, user_id, device_id, conn_id)) + row = txn.fetchone() + if row is None: + raise SlidingSyncUnknownPosition() + + (connection_key,) = row + + # Now that we have seen the client has received and used the connection + # position, we can delete all the other connection positions. + sql = """ + DELETE FROM sliding_sync_connection_positions + WHERE connection_key = ? AND connection_position != ? + """ + txn.execute(sql, (connection_key, connection_position)) + + # Fetch and create a mapping from required state ID to the actual + # required state for the connection. + rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_required_state", + keyvalues={"connection_key": connection_key}, + retcols=( + "required_state_id", + "required_state", + ), + ) + + required_state_map: Dict[int, Dict[str, Set[str]]] = {} + for row in rows: + state = required_state_map[row[0]] = {} + for event_type, state_key in db_to_json(row[1]): + state.setdefault(event_type, set()).add(state_key) + + # Get all the room configs, looking up the required state from the map + # above. + room_config_rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_room_configs", + keyvalues={"connection_position": connection_position}, + retcols=( + "room_id", + "timeline_limit", + "required_state_id", + ), + ) + + room_configs: Dict[str, RoomSyncConfig] = {} + for ( + room_id, + timeline_limit, + required_state_id, + ) in room_config_rows: + room_configs[room_id] = RoomSyncConfig( + timeline_limit=timeline_limit, + required_state_map=required_state_map[required_state_id], + ) + + # Now look up the per-room stream data. + rooms: Dict[str, HaveSentRoom[str]] = {} + receipts: Dict[str, HaveSentRoom[str]] = {} + account_data: Dict[str, HaveSentRoom[str]] = {} + + receipt_rows = self.db_pool.simple_select_list_txn( + txn, + table="sliding_sync_connection_streams", + keyvalues={"connection_position": connection_position}, + retcols=( + "stream", + "room_id", + "room_status", + "last_token", + ), + ) + for stream, room_id, room_status, last_token in receipt_rows: + have_sent_room: HaveSentRoom[str] = HaveSentRoom( + status=HaveSentRoomFlag(room_status), last_token=last_token + ) + if stream == "rooms": + rooms[room_id] = have_sent_room + elif stream == "receipts": + receipts[room_id] = have_sent_room + elif stream == "account_data": + account_data[room_id] = have_sent_room + else: + # For forwards compatibility we ignore unknown streams, as in + # future we want to be able to easily add more stream types. + logger.warning("Unrecognized sliding sync stream in DB %r", stream) + + return PerConnectionStateDB( + rooms=RoomStatusMap(rooms), + receipts=RoomStatusMap(receipts), + account_data=RoomStatusMap(account_data), + room_configs=room_configs, + ) + + +@attr.s(auto_attribs=True, frozen=True) +class PerConnectionStateDB: + """An equivalent to `PerConnectionState` that holds data in a format stored + in the DB. + + The principle difference is that the tokens for the different streams are + serialized to strings. + + When persisting this *only* contains updates to the state. + """ + + rooms: "RoomStatusMap[str]" + receipts: "RoomStatusMap[str]" + account_data: "RoomStatusMap[str]" + + room_configs: Mapping[str, "RoomSyncConfig"] + + @staticmethod + async def from_state( + per_connection_state: "MutablePerConnectionState", store: "DataStore" + ) -> "PerConnectionStateDB": + """Convert from a standard `PerConnectionState`""" + rooms = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await status.last_token.to_string(store) + if status.last_token is not None + else None + ), + ) + for room_id, status in per_connection_state.rooms.get_updates().items() + } + + receipts = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await status.last_token.to_string(store) + if status.last_token is not None + else None + ), + ) + for room_id, status in per_connection_state.receipts.get_updates().items() + } + + account_data = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + str(status.last_token) if status.last_token is not None else None + ), + ) + for room_id, status in per_connection_state.account_data.get_updates().items() + } + + log_kv( + { + "rooms": rooms, + "receipts": receipts, + "account_data": account_data, + "room_configs": per_connection_state.room_configs.maps[0], + } + ) + + return PerConnectionStateDB( + rooms=RoomStatusMap(rooms), + receipts=RoomStatusMap(receipts), + account_data=RoomStatusMap(account_data), + room_configs=per_connection_state.room_configs.maps[0], + ) + + async def to_state(self, store: "DataStore") -> "PerConnectionState": + """Convert into a standard `PerConnectionState`""" + rooms = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await RoomStreamToken.parse(store, status.last_token) + if status.last_token is not None + else None + ), + ) + for room_id, status in self.rooms._statuses.items() + } + + receipts = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + await MultiWriterStreamToken.parse(store, status.last_token) + if status.last_token is not None + else None + ), + ) + for room_id, status in self.receipts._statuses.items() + } + + account_data = { + room_id: HaveSentRoom( + status=status.status, + last_token=( + int(status.last_token) if status.last_token is not None else None + ), + ) + for room_id, status in self.account_data._statuses.items() + } + + return PerConnectionState( + rooms=RoomStatusMap(rooms), + receipts=RoomStatusMap(receipts), + account_data=RoomStatusMap(account_data), + room_configs=self.room_configs, + ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 62bc4600fb..788f7d1e32 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -308,8 +308,24 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=10000) - async def get_room_type(self, room_id: str) -> Optional[str]: - raise NotImplementedError() + async def get_room_type(self, room_id: str) -> Union[Optional[str], Sentinel]: + """Fetch room type for given room. + + Since this function is cached, any missing values would be cached as + `None`. In order to distinguish between an unencrypted room that has + `None` encryption and a room that is unknown to the server where we + might want to omit the value (which would make it cached as `None`), + instead we use the sentinel value `ROOM_UNKNOWN_SENTINEL`. + """ + + try: + create_event = await self.get_create_event_for_room(room_id) + return create_event.content.get(EventContentFields.ROOM_TYPE) + except NotFoundError: + # We use the sentinel value to distinguish between `None` which is a + # valid room type and a room that is unknown to the server so the value + # is just unset. + return ROOM_UNKNOWN_SENTINEL @cachedList(cached_method_name="get_room_type", list_name="room_ids") async def bulk_get_room_type( @@ -535,7 +551,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): desc="check_if_events_in_current_state", ) - return frozenset(event_id for event_id, in rows) + return frozenset(event_id for (event_id,) in rows) # FIXME: how should this be cached? @cancellable @@ -556,10 +572,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: Map from type/state_key to event ID. """ + if state_filter is None: + state_filter = StateFilter.all() - where_clause, where_args = ( - state_filter or StateFilter.all() - ).make_sql_filter_clause() + where_clause, where_args = (state_filter).make_sql_filter_clause() if not where_clause: # We delegate to the cached version @@ -568,7 +584,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + results = StateMapWrapper(state_filter=state_filter) sql = """ SELECT type, state_key, event_id FROM current_state_events @@ -665,7 +681,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): context: EventContext, ) -> None: """Update the state group for a partial state event""" - async with self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id: + async with ( + self._un_partial_stated_events_stream_id_gen.get_next() as un_partial_state_event_stream_id + ): await self.db_pool.runInteraction( "update_state_for_partial_state_event", self._update_state_for_partial_state_event_txn, @@ -736,6 +754,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" + MEMBERS_CURRENT_STATE_UPDATE_NAME = "current_state_events_members_room_index" def __init__( self, @@ -764,6 +783,13 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, ) + self.db_pool.updates.register_background_index_update( + self.MEMBERS_CURRENT_STATE_UPDATE_NAME, + index_name="current_state_events_members_room_index", + table="current_state_events", + columns=["room_id", "membership"], + where_clause="type='m.room.member'", + ) async def _background_remove_left_rooms( self, progress: JsonDict, batch_size: int diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 9ed39e688a..00f87cc3a1 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py
@@ -20,16 +20,25 @@ # import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import attr from synapse.logging.opentracing import trace from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import _filter_results_by_stream -from synapse.types import RoomStreamToken +from synapse.types import RoomStreamToken, StrCollection from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.iterutils import batch_iter + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -53,6 +62,21 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? _curr_state_delta_stream_cache: StreamChangeCache + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + update_name="current_state_delta_stream_room_index", + index_name="current_state_delta_stream_room_idx", + table="current_state_delta_stream", + columns=("room_id", "stream_id"), + ) + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[StateDelta]]: @@ -74,9 +98,9 @@ class StateDeltasStore(SQLBaseStore): prev_stream_id = int(prev_stream_id) # check we're not going backwards - assert ( - prev_stream_id <= max_stream_id - ), f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" + assert prev_stream_id <= max_stream_id, ( + f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}" + ) if not self._curr_state_delta_stream_cache.has_any_entity_changed( prev_stream_id @@ -160,38 +184,144 @@ class StateDeltasStore(SQLBaseStore): self._get_max_stream_id_in_current_state_deltas_txn, ) + def get_current_state_deltas_for_room_txn( + self, + txn: LoggingTransaction, + room_id: str, + *, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], + ) -> List[StateDelta]: + """ + Get the state deltas between two tokens. + + (> `from_token` and <= `to_token`) + """ + from_clause = "" + from_args = [] + if from_token is not None: + from_clause = "AND ? < stream_id" + from_args = [from_token.stream] + + to_clause = "" + to_args = [] + if to_token is not None: + to_clause = "AND stream_id <= ?" + to_args = [to_token.get_max_stream_pos()] + + sql = f""" + SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE room_id = ? {from_clause} {to_clause} + ORDER BY stream_id ASC + """ + txn.execute(sql, [room_id] + from_args + to_args) + + return [ + StateDelta( + stream_id=row[1], + room_id=room_id, + event_type=row[2], + state_key=row[3], + event_id=row[4], + prev_event_id=row[5], + ) + for row in txn + if _filter_results_by_stream(from_token, to_token, row[0], row[1]) + ] + @trace async def get_current_state_deltas_for_room( - self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken + self, + room_id: str, + *, + from_token: Optional[RoomStreamToken], + to_token: Optional[RoomStreamToken], ) -> List[StateDelta]: - """Get the state deltas between two tokens.""" + """ + Get the state deltas between two tokens. + + (> `from_token` and <= `to_token`) + """ + # We can bail early if the `from_token` is after the `to_token` + if ( + to_token is not None + and from_token is not None + and to_token.is_before_or_eq(from_token) + ): + return [] - def get_current_state_deltas_for_room_txn( + if ( + from_token is not None + and not self._curr_state_delta_stream_cache.has_entity_changed( + room_id, from_token.stream + ) + ): + return [] + + return await self.db_pool.runInteraction( + "get_current_state_deltas_for_room", + self.get_current_state_deltas_for_room_txn, + room_id, + from_token=from_token, + to_token=to_token, + ) + + @trace + async def get_current_state_deltas_for_rooms( + self, + room_ids: StrCollection, + from_token: RoomStreamToken, + to_token: RoomStreamToken, + ) -> List[StateDelta]: + """Get the state deltas between two tokens for the set of rooms.""" + + room_ids = self._curr_state_delta_stream_cache.get_entities_changed( + room_ids, from_token.stream + ) + if not room_ids: + return [] + + def get_current_state_deltas_for_rooms_txn( txn: LoggingTransaction, + room_ids: StrCollection, ) -> List[StateDelta]: - sql = """ - SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, type, state_key, event_id, prev_event_id FROM current_state_delta_stream - WHERE room_id = ? AND ? < stream_id AND stream_id <= ? + WHERE {clause} AND ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC """ - txn.execute( - sql, (room_id, from_token.stream, to_token.get_max_stream_pos()) - ) + args.append(from_token.stream) + args.append(to_token.get_max_stream_pos()) + + txn.execute(sql, args) return [ StateDelta( stream_id=row[1], - room_id=room_id, - event_type=row[2], - state_key=row[3], - event_id=row[4], - prev_event_id=row[5], + room_id=row[2], + event_type=row[3], + state_key=row[4], + event_id=row[5], + prev_event_id=row[6], ) for row in txn if _filter_results_by_stream(from_token, to_token, row[0], row[1]) ] - return await self.db_pool.runInteraction( - "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn - ) + results = [] + for batch in batch_iter(room_ids, 1000): + deltas = await self.db_pool.runInteraction( + "get_current_state_deltas_for_rooms", + get_current_state_deltas_for_rooms_txn, + batch, + ) + + results.extend(deltas) + + return results diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index e9f6a918c7..79c49e7fd9 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -161,7 +161,7 @@ class StatsStore(StateDeltasStore): LIMIT ? """ txn.execute(sql, (last_user_id, batch_size)) - return [r for r, in txn] + return [r for (r,) in txn] users_to_work_on = await self.db_pool.runInteraction( "_populate_stats_process_users", _get_next_batch @@ -207,7 +207,7 @@ class StatsStore(StateDeltasStore): LIMIT ? """ txn.execute(sql, (last_room_id, batch_size)) - return [r for r, in txn] + return [r for (r,) in txn] rooms_to_work_on = await self.db_pool.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch @@ -751,9 +751,7 @@ class StatsStore(StateDeltasStore): LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id {} GROUP BY lmr.user_id, displayname - """.format( - where_clause - ) + """.format(where_clause) # SQLite does not support SELECT COUNT(*) OVER() sql = """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4989c960a6..3fda49f31f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -21,7 +21,7 @@ # # -""" This module is responsible for getting events from the DB for pagination +"""This module is responsible for getting events from the DB for pagination and event streaming. The order it returns events in depend on whether we are streaming forwards or @@ -50,6 +50,8 @@ from typing import ( Dict, Iterable, List, + Literal, + Mapping, Optional, Protocol, Set, @@ -60,7 +62,7 @@ from typing import ( import attr from immutabledict import immutabledict -from typing_extensions import Literal, assert_never +from typing_extensions import assert_never from twisted.internet import defer @@ -78,9 +80,10 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine +from synapse.storage.roommember import RoomsForUserStateReset from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -107,7 +110,7 @@ class PaginateFunction(Protocol): to_key: Optional[RoomStreamToken] = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - ) -> Tuple[List[EventBase], RoomStreamToken]: ... + ) -> Tuple[List[EventBase], RoomStreamToken, bool]: ... # Used as return values for pagination APIs @@ -451,6 +454,8 @@ def _filter_results_by_stream( stream_ordering falls between the two tokens (taking a None token to mean unbounded). + The token range is defined by > `lower_token` and <= `upper_token`. + Used to filter results from fetching events in the DB against the given tokens. This is necessary to handle the case where the tokens include position maps, which we handle by fetching more than necessary from the DB @@ -678,7 +683,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): to_key: Optional[RoomStreamToken] = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]: + ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken, bool]]: """Get new room events in stream ordering since `from_key`. Args: @@ -694,6 +699,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): A map from room id to a tuple containing: - list of recent events in the room - stream ordering key for the start of the chunk of events returned. + - a boolean to indicate if there were more events but we hit the limit When Direction.FORWARDS: from_key < x <= to_key, (ascending order) When Direction.BACKWARDS: from_key >= x > to_key, (descending order) @@ -749,6 +755,48 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if self._events_stream_cache.has_entity_changed(room_id, from_id) } + async def get_rooms_that_have_updates_since_sliding_sync_table( + self, + room_ids: StrCollection, + from_key: RoomStreamToken, + ) -> StrCollection: + """Return the rooms that probably have had updates since the given + token (changes that are > `from_key`).""" + # If the stream change cache is valid for the stream token, we can just + # use the result of that. + if from_key.stream >= self._events_stream_cache.get_earliest_known_position(): + return self._events_stream_cache.get_entities_changed( + room_ids, from_key.stream + ) + + def get_rooms_that_have_updates_since_sliding_sync_table_txn( + txn: LoggingTransaction, + ) -> StrCollection: + sql = """ + SELECT room_id + FROM sliding_sync_joined_rooms + WHERE {clause} + AND event_stream_ordering > ? + """ + + results: Set[str] = set() + for batch in batch_iter(room_ids, 1000): + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch + ) + + args.append(from_key.stream) + txn.execute(sql.format(clause=clause), args) + + results.update(row[0] for row in txn) + + return results + + return await self.db_pool.runInteraction( + "get_rooms_that_have_updates_since_sliding_sync_table", + get_rooms_that_have_updates_since_sliding_sync_table_txn, + ) + async def paginate_room_events_by_stream_ordering( self, *, @@ -757,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): to_key: Optional[RoomStreamToken] = None, direction: Direction = Direction.BACKWARDS, limit: int = 0, - ) -> Tuple[List[EventBase], RoomStreamToken]: + ) -> Tuple[List[EventBase], RoomStreamToken, bool]: """ Paginate events by `stream_ordering` in the room from the `from_key` in the given `direction` to the `to_key` or `limit`. @@ -772,8 +820,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): limit: Maximum number of events to return Returns: - The results as a list of events and a token that points to the end - of the result set. If no events are returned then the end of the + The results as a list of events, a token that points to the end of + the result set, and a boolean to indicate if there were more events + but we hit the limit. If no events are returned then the end of the stream has been reached (i.e. there are no events between `from_key` and `to_key`). @@ -797,7 +846,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and to_key.is_before_or_eq(from_key) ): # Token selection matches what we do below if there are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False # Or vice-versa, if we're looking backwards and our `from_key` is already before # our `to_key`. elif ( @@ -806,7 +855,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and from_key.is_before_or_eq(to_key) ): # Token selection matches what we do below if there are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False # We can do a quick sanity check to see if any events have been sent in the room # since the earlier token. @@ -825,7 +874,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if not has_changed: # Token selection matches what we do below if there are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False order, from_bound, to_bound = generate_pagination_bounds( direction, from_key, to_key @@ -841,7 +890,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): engine=self.database_engine, ) - def f(txn: LoggingTransaction) -> List[_EventDictReturn]: + def f(txn: LoggingTransaction) -> Tuple[List[_EventDictReturn], bool]: sql = f""" SELECT event_id, instance_name, stream_ordering FROM events @@ -853,9 +902,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): """ txn.execute(sql, (room_id, 2 * limit)) + # Get all the rows and check if we hit the limit. + fetched_rows = txn.fetchall() + limited = len(fetched_rows) >= 2 * limit + rows = [ _EventDictReturn(event_id, None, stream_ordering) - for event_id, instance_name, stream_ordering in txn + for event_id, instance_name, stream_ordering in fetched_rows if _filter_results_by_stream( lower_token=( to_key if direction == Direction.BACKWARDS else from_key @@ -866,10 +919,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): instance_name=instance_name, stream_ordering=stream_ordering, ) - ][:limit] - return rows + ] + + if len(rows) > limit: + limited = True - rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f) + rows = rows[:limit] + return rows, limited + + rows, limited = await self.db_pool.runInteraction( + "get_room_events_stream_for_room", f + ) ret = await self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -886,7 +946,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # `_paginate_room_events_by_topological_ordering_txn(...)`) next_key = to_key if to_key else from_key - return ret, next_key + return ret, next_key, limited @trace async def get_current_state_delta_membership_changes_for_user( @@ -927,7 +987,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: All membership changes to the current state in the token range. Events are sorted by `stream_ordering` ascending. + + `event_id`/`sender` can be `None` when the server leaves a room (meaning + everyone locally left) or a state reset which removed the person from the + room. We can't tell the difference between the two cases with what's + available in the `current_state_delta_stream` table. To actually check for a + state reset, you need to check if a membership still exists in the room. """ + + assert from_key.topological is None + assert to_key.topological is None + # Start by ruling out cases where a DB query is not necessary. if from_key == to_key: return [] @@ -1038,6 +1108,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): membership=( membership if membership is not None else Membership.LEAVE ), + # This will also be null for the same reasons if `s.event_id = null` sender=sender, # Prev event prev_event_id=prev_event_id, @@ -1072,6 +1143,203 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if membership_change.room_id not in room_ids_to_exclude ] + @trace + async def get_sliding_sync_membership_changes( + self, + user_id: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, + excluded_room_ids: Optional[AbstractSet[str]] = None, + ) -> Dict[str, RoomsForUserStateReset]: + """ + Fetch membership events that result in a meaningful membership change for a + given user. + + A meaningful membership changes is one where the `membership` value actually + changes. This means memberships changes from `join` to `join` (like a display + name change) will be filtered out since they result in no meaningful change. + + Note: This function only works with "live" tokens with `stream_ordering` only. + + We're looking for membership changes in the token range (> `from_key` and <= + `to_key`). + + Args: + user_id: The user ID to fetch membership events for. + from_key: The point in the stream to sync from (fetching events > this point). + to_key: The token to fetch rooms up to (fetching events <= this point). + excluded_room_ids: Optional list of room IDs to exclude from the results. + + Returns: + All meaningful membership changes to the current state in the token range. + Events are sorted by `stream_ordering` ascending. + + `event_id`/`sender` can be `None` when the server leaves a room (meaning + everyone locally left) or a state reset which removed the person from the + room. We can't tell the difference between the two cases with what's + available in the `current_state_delta_stream` table. To actually check for a + state reset, you need to check if a membership still exists in the room. + """ + + assert from_key.topological is None + assert to_key.topological is None + + # Start by ruling out cases where a DB query is not necessary. + if from_key == to_key: + return {} + + if from_key: + has_changed = self._membership_stream_cache.has_entity_changed( + user_id, int(from_key.stream) + ) + if not has_changed: + return {} + + room_ids_to_exclude: AbstractSet[str] = set() + if excluded_room_ids is not None: + room_ids_to_exclude = excluded_room_ids + + def f(txn: LoggingTransaction) -> Dict[str, RoomsForUserStateReset]: + # To handle tokens with a non-empty instance_map we fetch more + # results than necessary and then filter down + min_from_id = from_key.stream + max_to_id = to_key.get_max_stream_pos() + + # This query looks at membership changes in + # `sliding_sync_membership_snapshots` which will not include users + # that were state reset out of rooms; so we need to look for that + # case in `current_state_delta_stream`. + sql = """ + SELECT + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version + FROM + ( + SELECT + s.room_id, + s.membership_event_id, + s.event_instance_name, + s.event_stream_ordering, + s.membership, + s.sender, + m_prev.membership AS prev_membership + FROM sliding_sync_membership_snapshots as s + LEFT JOIN event_edges AS e ON e.event_id = s.membership_event_id + LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = e.prev_event_id + WHERE s.user_id = ? + + UNION ALL + + SELECT + s.room_id, + e.event_id, + s.instance_name, + s.stream_id, + m.membership, + e.sender, + m_prev.membership AS prev_membership + FROM current_state_delta_stream AS s + LEFT JOIN events AS e ON e.event_id = s.event_id + LEFT JOIN room_memberships AS m ON m.event_id = s.event_id + LEFT JOIN room_memberships AS m_prev ON m_prev.event_id = s.prev_event_id + WHERE + s.type = ? + AND s.state_key = ? + ) AS c + INNER JOIN rooms USING (room_id) + WHERE event_stream_ordering > ? AND event_stream_ordering <= ? + ORDER BY event_stream_ordering ASC + """ + + txn.execute( + sql, + (user_id, EventTypes.Member, user_id, min_from_id, max_to_id), + ) + + membership_changes: Dict[str, RoomsForUserStateReset] = {} + for ( + room_id, + membership_event_id, + event_instance_name, + event_stream_ordering, + membership, + sender, + prev_membership, + room_version_id, + ) in txn: + assert room_id is not None + assert event_stream_ordering is not None + + if room_id in room_ids_to_exclude: + continue + + if _filter_results_by_stream( + from_key, + to_key, + event_instance_name, + event_stream_ordering, + ): + # When the server leaves a room, it will insert new rows into the + # `current_state_delta_stream` table with `event_id = null` for all + # current state. This means we might already have a row for the + # leave event and then another for the same leave where the + # `event_id=null` but the `prev_event_id` is pointing back at the + # earlier leave event. We don't want to report the leave, if we + # already have a leave event. + if ( + membership_event_id is None + and prev_membership == Membership.LEAVE + ): + continue + + if membership_event_id is None and room_id in membership_changes: + # SUSPICIOUS: if we join a room and get state reset out of it + # in the same queried window, + # won't this ignore the 'state reset out of it' part? + continue + + # When `s.event_id = null`, we won't be able to get respective + # `room_membership` but can assume the user has left the room + # because this only happens when the server leaves a room + # (meaning everyone locally left) or a state reset which removed + # the person from the room. + membership = ( + membership if membership is not None else Membership.LEAVE + ) + + if membership == prev_membership: + # If `membership` and `prev_membership` are the same then this + # is not a meaningful change so we can skip it. + # An example of this happening is when the user changes their display name. + continue + + membership_change = RoomsForUserStateReset( + room_id=room_id, + sender=sender, + membership=membership, + event_id=membership_event_id, + event_pos=PersistedEventPosition( + event_instance_name, event_stream_ordering + ), + room_version_id=room_version_id, + ) + + membership_changes[room_id] = membership_change + + return membership_changes + + membership_changes = await self.db_pool.runInteraction( + "get_sliding_sync_membership_changes", f + ) + + return membership_changes + @cancellable async def get_membership_changes_for_user( self, @@ -1121,9 +1389,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): AND e.stream_ordering > ? AND e.stream_ordering <= ? %s ORDER BY e.stream_ordering ASC - """ % ( - ignore_room_clause, - ) + """ % (ignore_room_clause,) txn.execute(sql, args) @@ -1192,7 +1458,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if limit == 0: return [], end_token - rows, token = await self.db_pool.runInteraction( + rows, token, _ = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_by_topological_ordering_txn, room_id, @@ -1263,12 +1529,76 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return None + async def get_last_event_pos_in_room( + self, + room_id: str, + event_types: Optional[StrCollection] = None, + ) -> Optional[Tuple[str, PersistedEventPosition]]: + """ + Returns the ID and event position of the last event in a room. + + Based on `get_last_event_pos_in_room_before_stream_ordering(...)` + + Args: + room_id + event_types: Optional allowlist of event types to filter by + + Returns: + The ID of the most recent event and it's position, or None if there are no + events in the room that match the given event types. + """ + + def _get_last_event_pos_in_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, PersistedEventPosition]]: + event_type_clause = "" + event_type_args: List[str] = [] + if event_types is not None and len(event_types) > 0: + event_type_clause, event_type_args = make_in_list_sql_clause( + txn.database_engine, "type", event_types + ) + event_type_clause = f"AND {event_type_clause}" + + sql = f""" + SELECT event_id, stream_ordering, instance_name + FROM events + LEFT JOIN rejections USING (event_id) + WHERE room_id = ? + {event_type_clause} + AND NOT outlier + AND rejections.event_id IS NULL + ORDER BY stream_ordering DESC + LIMIT 1 + """ + + txn.execute( + sql, + [room_id] + event_type_args, + ) + + row = cast(Optional[Tuple[str, int, str]], txn.fetchone()) + if row is not None: + event_id, stream_ordering, instance_name = row + + return event_id, PersistedEventPosition( + # If instance_name is null we default to "master" + instance_name or "master", + stream_ordering, + ) + + return None + + return await self.db_pool.runInteraction( + "get_last_event_pos_in_room", + _get_last_event_pos_in_room_txn, + ) + @trace async def get_last_event_pos_in_room_before_stream_ordering( self, room_id: str, end_token: RoomStreamToken, - event_types: Optional[Collection[str]] = None, + event_types: Optional[StrCollection] = None, ) -> Optional[Tuple[str, PersistedEventPosition]]: """ Returns the ID and event position of the last event in a room at or before a @@ -1381,8 +1711,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rooms """ + # First we just get the latest positions for the room, as the vast + # majority of them will be before the given end token anyway. By doing + # this we can cache most rooms. + uncapped_results = await self._bulk_get_max_event_pos(room_ids) + + # Check that the stream position for the rooms are from before the + # minimum position of the token. If not then we need to fetch more + # rows. + results: Dict[str, int] = {} + recheck_rooms: Set[str] = set() min_token = end_token.stream - max_token = end_token.get_max_stream_pos() + for room_id, stream in uncapped_results.items(): + if stream is None: + # Despite the function not directly setting None, the cache can! + # See: https://github.com/element-hq/synapse/issues/17726 + continue + if stream <= min_token: + results[room_id] = stream + else: + recheck_rooms.add(room_id) + + if not recheck_rooms: + return results + + # There shouldn't be many rooms that we need to recheck, so we do them + # one-by-one. + for room_id in recheck_rooms: + result = await self.get_last_event_pos_in_room_before_stream_ordering( + room_id, end_token + ) + if result is not None: + results[room_id] = result[1].stream + + return results + + @cached() + async def _get_max_event_pos(self, room_id: str) -> int: + raise NotImplementedError() + + @cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids") + async def _bulk_get_max_event_pos( + self, room_ids: StrCollection + ) -> Mapping[str, Optional[int]]: + """Fetch the max position of a persisted event in the room.""" + + # We need to be careful not to return positions ahead of the current + # positions, so we get the current token now and cap our queries to it. + now_token = self.get_room_max_token() + max_pos = now_token.get_max_stream_pos() + results: Dict[str, int] = {} # First, we check for the rooms in the stream change cache to see if we @@ -1390,31 +1768,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): missing_room_ids: Set[str] = set() for room_id in room_ids: stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id) - if stream_pos and stream_pos <= min_token: + if stream_pos is not None: results[room_id] = stream_pos else: missing_room_ids.add(room_id) + if not missing_room_ids: + return results + # Next, we query the stream position from the DB. At first we fetch all # positions less than the *max* stream pos in the token, then filter # them down. We do this as a) this is a cheaper query, and b) the vast # majority of rooms will have a latest token from before the min stream # pos. - def bulk_get_last_event_pos_txn( - txn: LoggingTransaction, batch_room_ids: StrCollection + def bulk_get_max_event_pos_fallback_txn( + txn: LoggingTransaction, batched_room_ids: StrCollection ) -> Dict[str, int]: - # This query fetches the latest stream position in the rooms before - # the given max position. clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", batch_room_ids + self.database_engine, "room_id", batched_room_ids ) sql = f""" SELECT room_id, ( SELECT stream_ordering FROM events AS e LEFT JOIN rejections USING (event_id) WHERE e.room_id = r.room_id - AND stream_ordering <= ? + AND e.stream_ordering <= ? AND NOT outlier AND rejection_reason IS NULL ORDER BY stream_ordering DESC @@ -1423,72 +1802,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): FROM rooms AS r WHERE {clause} """ - txn.execute(sql, [max_token] + args) + txn.execute(sql, [max_pos] + args) return {row[0]: row[1] for row in txn} - recheck_rooms: Set[str] = set() - for batched in batch_iter(missing_room_ids, 1000): - result = await self.db_pool.runInteraction( - "bulk_get_last_event_pos_in_room_before_stream_ordering", - bulk_get_last_event_pos_txn, - batched, - ) - - # Check that the stream position for the rooms are from before the - # minimum position of the token. If not then we need to fetch more - # rows. - for room_id, stream in result.items(): - if stream <= min_token: - results[room_id] = stream - else: - recheck_rooms.add(room_id) - - if not recheck_rooms: - return results - - # For the remaining rooms we need to fetch all rows between the min and - # max stream positions in the end token, and filter out the rows that - # are after the end token. - # - # This query should be fast as the range between the min and max should - # be small. - - def bulk_get_last_event_pos_recheck_txn( - txn: LoggingTransaction, batch_room_ids: StrCollection + # It's easier to look at the `sliding_sync_joined_rooms` table and avoid all of + # the joins and sub-queries. + def bulk_get_max_event_pos_from_sliding_sync_tables_txn( + txn: LoggingTransaction, batched_room_ids: StrCollection ) -> Dict[str, int]: clause, args = make_in_list_sql_clause( - self.database_engine, "room_id", batch_room_ids + self.database_engine, "room_id", batched_room_ids ) sql = f""" - SELECT room_id, instance_name, stream_ordering - FROM events - WHERE ? < stream_ordering AND stream_ordering <= ? - AND NOT outlier - AND rejection_reason IS NULL - AND {clause} - ORDER BY stream_ordering ASC + SELECT room_id, event_stream_ordering + FROM sliding_sync_joined_rooms + WHERE {clause} + ORDER BY event_stream_ordering DESC """ - txn.execute(sql, [min_token, max_token] + args) - - # We take the max stream ordering that is less than the token. Since - # we ordered by stream ordering we just need to iterate through and - # take the last matching stream ordering. - txn_results: Dict[str, int] = {} - for row in txn: - room_id = row[0] - event_pos = PersistedEventPosition(row[1], row[2]) - if not event_pos.persisted_after(end_token): - txn_results[room_id] = event_pos.stream - - return txn_results - - for batched in batch_iter(recheck_rooms, 1000): - recheck_result = await self.db_pool.runInteraction( - "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck", - bulk_get_last_event_pos_recheck_txn, - batched, + txn.execute(sql, args) + return {row[0]: row[1] for row in txn} + + recheck_rooms: Set[str] = set() + for batched in batch_iter(room_ids, 1000): + if await self.have_finished_sliding_sync_background_jobs(): + batch_results = await self.db_pool.runInteraction( + "bulk_get_max_event_pos_from_sliding_sync_tables_txn", + bulk_get_max_event_pos_from_sliding_sync_tables_txn, + batched, + ) + else: + batch_results = await self.db_pool.runInteraction( + "bulk_get_max_event_pos_fallback_txn", + bulk_get_max_event_pos_fallback_txn, + batched, + ) + for room_id, stream_ordering in batch_results.items(): + if stream_ordering <= now_token.stream: + results[room_id] = stream_ordering + else: + recheck_rooms.add(room_id) + + # We now need to handle rooms where the above query returned a stream + # position that was potentially too new. This should happen very rarely + # so we just query the rooms one-by-one + for room_id in recheck_rooms: + result = await self.get_last_event_pos_in_room_before_stream_ordering( + room_id, now_token ) - results.update(recheck_result) + if result is not None: + results[room_id] = result[1].stream return results @@ -1680,15 +2042,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - stream_ordering, topological_ordering = cast( - Tuple[int, int], - self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], - ), + row = self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=("stream_ordering", "topological_ordering"), ) + stream_ordering = int(row[0]) + topological_ordering = int(row[1]) # Paginating backwards includes the event at the token, but paginating # forward doesn't. @@ -1700,7 +2061,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): topological=topological_ordering, stream=stream_ordering ) - rows, start_token = self._paginate_room_events_by_topological_ordering_txn( + rows, start_token, _ = self._paginate_room_events_by_topological_ordering_txn( txn, room_id, before_token, @@ -1710,7 +2071,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) events_before = [r.event_id for r in rows] - rows, end_token = self._paginate_room_events_by_topological_ordering_txn( + rows, end_token, _ = self._paginate_room_events_by_topological_ordering_txn( txn, room_id, after_token, @@ -1882,7 +2243,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): direction: Direction = Direction.BACKWARDS, limit: int = 0, event_filter: Optional[Filter] = None, - ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: + ) -> Tuple[List[_EventDictReturn], RoomStreamToken, bool]: """Returns list of events before or after a given token. Args: @@ -1897,10 +2258,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): those that match the filter. Returns: - A list of _EventDictReturn and a token that points to the end of the - result set. If no events are returned then the end of the stream has - been reached (i.e. there are no events between `from_token` and - `to_token`), or `limit` is zero. + A list of _EventDictReturn, a token that points to the end of the + result set, and a boolean to indicate if there were more events but + we hit the limit. If no events are returned then the end of the + stream has been reached (i.e. there are no events between + `from_token` and `to_token`), or `limit` is zero. """ # We can bail early if we're looking forwards, and our `to_key` is already # before our `from_token`. @@ -1910,7 +2272,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and to_token.is_before_or_eq(from_token) ): # Token selection matches what we do below if there are no rows - return [], to_token if to_token else from_token + return [], to_token if to_token else from_token, False # Or vice-versa, if we're looking backwards and our `from_token` is already before # our `to_token`. elif ( @@ -1919,7 +2281,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and from_token.is_before_or_eq(to_token) ): # Token selection matches what we do below if there are no rows - return [], to_token if to_token else from_token + return [], to_token if to_token else from_token, False args: List[Any] = [room_id] @@ -1942,6 +2304,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): args.extend(filter_args) # We fetch more events as we'll filter the result set + requested_limit = int(limit) * 2 args.append(int(limit) * 2) select_keywords = "SELECT" @@ -2006,10 +2369,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } txn.execute(sql, args) + # Get all the rows and check if we hit the limit. + fetched_rows = txn.fetchall() + limited = len(fetched_rows) >= requested_limit + # Filter the result set. rows = [ _EventDictReturn(event_id, topological_ordering, stream_ordering) - for event_id, instance_name, topological_ordering, stream_ordering in txn + for event_id, instance_name, topological_ordering, stream_ordering in fetched_rows if _filter_results( lower_token=( to_token if direction == Direction.BACKWARDS else from_token @@ -2021,7 +2388,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): topological_ordering=topological_ordering, stream_ordering=stream_ordering, ) - ][:limit] + ] + + if len(rows) > limit: + limited = True + + rows = rows[:limit] if rows: assert rows[-1].topological_ordering is not None @@ -2032,7 +2404,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token - return rows, next_token + return rows, next_token, limited @trace @tag_args @@ -2045,7 +2417,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): direction: Direction = Direction.BACKWARDS, limit: int = 0, event_filter: Optional[Filter] = None, - ) -> Tuple[List[EventBase], RoomStreamToken]: + ) -> Tuple[List[EventBase], RoomStreamToken, bool]: """ Paginate events by `topological_ordering` (tie-break with `stream_ordering`) in the room from the `from_key` in the given `direction` to the `to_key` or @@ -2062,8 +2434,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter: If provided filters the events to those that match the filter. Returns: - The results as a list of events and a token that points to the end - of the result set. If no events are returned then the end of the + The results as a list of events, a token that points to the end of + the result set, and a boolean to indicate if there were more events + but we hit the limit. If no events are returned then the end of the stream has been reached (i.e. there are no events between `from_key` and `to_key`). @@ -2087,7 +2460,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ): # Token selection matches what we do in `_paginate_room_events_txn` if there # are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False # Or vice-versa, if we're looking backwards and our `from_key` is already before # our `to_key`. elif ( @@ -2097,9 +2470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ): # Token selection matches what we do in `_paginate_room_events_txn` if there # are no rows - return [], to_key if to_key else from_key + return [], to_key if to_key else from_key, False - rows, token = await self.db_pool.runInteraction( + rows, token, limited = await self.db_pool.runInteraction( "paginate_room_events_by_topological_ordering", self._paginate_room_events_by_topological_ordering_txn, room_id, @@ -2114,7 +2487,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): [r.event_id for r in rows], get_prev_content=True ) - return events, token + return events, token, limited @cached() async def get_id_for_instance(self, instance_name: str) -> int: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b5af294384..97b190bccc 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -158,9 +158,56 @@ class TagsWorkerStore(AccountDataWorkerStore): return results + async def has_tags_changed_for_room( + self, + # Since there are multiple arguments with the same type, force keyword arguments + # so people don't accidentally swap the order + *, + user_id: str, + room_id: str, + from_stream_id: int, + to_stream_id: int, + ) -> bool: + """Check if the users tags for a room have been updated in the token range + + (> `from_stream_id` and <= `to_stream_id`) + + Args: + user_id: The user to get tags for + room_id: The room to get tags for + from_stream_id: The point in the stream to fetch from + to_stream_id: The point in the stream to fetch to + + Returns: + A mapping of tags to tag content. + """ + + # Shortcut if no room has changed for the user + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(from_stream_id) + ) + if not changed: + return False + + last_change_position_for_room = await self.db_pool.simple_select_one_onecol( + table="room_tags_revisions", + keyvalues={"user_id": user_id, "room_id": room_id}, + retcol="stream_id", + allow_none=True, + ) + + if last_change_position_for_room is None: + return False + + return ( + last_change_position_for_room > from_stream_id + and last_change_position_for_room <= to_stream_id + ) + + @cached(num_args=2, tree=True) async def get_tags_for_room( self, user_id: str, room_id: str - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get all the tags for the given room Args: @@ -182,7 +229,7 @@ class TagsWorkerStore(AccountDataWorkerStore): return {tag: db_to_json(content) for tag, content in rows} async def add_tag_to_room( - self, user_id: str, room_id: str, tag: str, content: JsonDict + self, user_id: str, room_id: str, tag: str, content: JsonMapping ) -> int: """Add a tag to a room for a user. @@ -213,6 +260,7 @@ class TagsWorkerStore(AccountDataWorkerStore): await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) + self.get_tags_for_room.invalidate((user_id, room_id)) return self._account_data_id_gen.get_current_token() @@ -226,10 +274,7 @@ class TagsWorkerStore(AccountDataWorkerStore): assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: - sql = ( - "DELETE FROM room_tags " - " WHERE user_id = ? AND room_id = ? AND tag = ?" - ) + sql = "DELETE FROM room_tags WHERE user_id = ? AND room_id = ? AND tag = ?" txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) @@ -237,6 +282,7 @@ class TagsWorkerStore(AccountDataWorkerStore): await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) + self.get_tags_for_room.invalidate((user_id, room_id)) return self._account_data_id_gen.get_current_token() @@ -290,9 +336,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == AccountDataStream.NAME: - for row in rows: + # Cast is safe because the `AccountDataStream` should only be giving us + # `AccountDataStreamRow` + account_data_stream_rows: List[AccountDataStream.AccountDataStreamRow] = ( + cast(List[AccountDataStream.AccountDataStreamRow], rows) + ) + + for row in account_data_stream_rows: if row.data_type == AccountDataTypes.TAG: self.get_tags_for_user.invalidate((row.user_id,)) + if row.room_id: + self.get_tags_for_room.invalidate((row.user_id, row.room_id)) + else: + self.get_tags_for_room.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed( row.user_id, token ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 770802483c..bfc324b80d 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -86,10 +86,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): @wrap_as_background_process("cleanup_transactions") async def _cleanup_transactions(self) -> None: now = self._clock.time_msec() - month_ago = now - 30 * 24 * 60 * 60 * 1000 + day_ago = now - 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: - txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) + txn.execute("DELETE FROM received_transactions WHERE ts < ?", (day_ago,)) await self.db_pool.runInteraction( "_cleanup_transactions", _cleanup_transactions_txn diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 6e18f714d7..31a8ce6666 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -31,6 +31,7 @@ from typing import ( Sequence, Set, Tuple, + TypedDict, cast, ) @@ -42,10 +43,9 @@ try: USE_ICU = True except ModuleNotFoundError: + # except ModuleNotFoundError: USE_ICU = False -from typing_extensions import TypedDict - from synapse.api.errors import StoreError from synapse.util.stringutils import non_null_str_or_none @@ -224,9 +224,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): SELECT room_id, events FROM %s ORDER BY events DESC LIMIT 250 - """ % ( - TEMP_TABLE + "_rooms", - ) + """ % (TEMP_TABLE + "_rooms",) txn.execute(sql) rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall()) @@ -585,9 +583,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): retry_counter: number of failures in refreshing the profile so far. Used for exponential backoff calculations. """ - assert not self.hs.is_mine_id( - user_id - ), "Can't mark a local user as a stale remote user." + assert not self.hs.is_mine_id(user_id), ( + "Can't mark a local user as a stale remote user." + ) server_name = UserID.from_string(user_id).domain @@ -1040,11 +1038,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): } """ + join_args: Tuple[str, ...] = (user_id,) + if self.hs.config.userdirectory.user_directory_search_all_users: - join_args = (user_id,) where_clause = "user_id != ?" else: - join_args = (user_id,) where_clause = """ ( EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id) @@ -1058,6 +1056,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): if not show_locked_users: where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)" + # Adjust the JOIN type based on the exclude_remote_users flag (the users + # table only contains local users so an inner join is a good way to + # to exclude remote users) + if self.hs.config.userdirectory.user_directory_exclude_remote_users: + join_type = "JOIN" + else: + join_type = "LEFT JOIN" + # We allow manipulating the ranking algorithm by injecting statements # based on config options. additional_ordering_statements = [] @@ -1089,7 +1095,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM matching_users as t INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users AS u ON t.user_id = u.name + %(join_type)s users AS u ON t.user_id = u.name WHERE %(where_clause)s ORDER BY @@ -1118,6 +1124,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ % { "where_clause": where_clause, "order_case_statements": " ".join(additional_ordering_statements), + "join_type": join_type, } args = ( (full_query,) @@ -1145,7 +1152,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) - LEFT JOIN users AS u ON t.user_id = u.name + %(join_type)s users AS u ON t.user_id = u.name WHERE %(where_clause)s AND value MATCH ? @@ -1158,6 +1165,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ % { "where_clause": where_clause, "order_statements": " ".join(additional_ordering_statements), + "join_type": join_type, } args = join_args + (search_query,) + ordering_arguments + (limit + 1,) else: @@ -1240,7 +1248,13 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]: search_term = _filter_text_for_index(search_term) escaped_words = [] - for word in _parse_words(search_term): + for index, word in enumerate(_parse_words(search_term)): + if index >= 10: + # We limit how many terms we include, as otherwise it can use + # excessive database time if people accidentally search for large + # strings. + break + # Postgres tsvector and tsquery quoting rules: # words potentially containing punctuation should be quoted # and then existing quotes and backslashes should be doubled diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index ea7d8199a7..5b594fe8dd 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -20,7 +20,15 @@ # import logging -from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore @@ -112,8 +120,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): Returns: Map from state_group to a StateMap at that point. """ - - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} @@ -388,8 +396,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): return True, count txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", + "SELECT state_group FROM state_group_edges WHERE state_group = ?", (state_group,), ) diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py new file mode 100644
index 0000000000..f77c46f6ae --- /dev/null +++ b/synapse/storage/databases/state/deletion.py
@@ -0,0 +1,561 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# + + +import contextlib +from typing import ( + TYPE_CHECKING, + AbstractSet, + AsyncIterator, + Collection, + Mapping, + Optional, + Set, + Tuple, +) + +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.engines import PostgresEngine +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StateDeletionDataStore: + """Manages deletion of state groups in a safe manner. + + Deleting state groups is challenging as before we actually delete them we + need to ensure that there are no in-flight events that refer to the state + groups that we want to delete. + + To handle this, we take two approaches. First, before we persist any event + we ensure that the state group still exists and mark in the + `state_groups_persisting` table that the state group is about to be used. + (Note that we have to have the extra table here as state groups and events + can be in different databases, and thus we can't check for the existence of + state groups in the persist event transaction). Once the event has been + persisted, we can remove the row from `state_groups_persisting`. So long as + we check that table before deleting state groups, we can ensure that we + never persist events that reference deleted state groups, maintaining + database integrity. + + However, we want to avoid throwing exceptions so deep in the process of + persisting events. So instead of deleting state groups immediately, we mark + them as pending/proposed for deletion and wait for a certain amount of time + before performing the deletion. When we come to handle new events that + reference state groups, we check if they are pending deletion and bump the + time for when they'll be deleted (to give a chance for the event to be + persisted, or not). + + When deleting, we need to check that state groups remain unreferenced. There + is a race here where we a) fetch state groups that are ready for deletion, + b) check they're unreferenced, c) the state group becomes referenced but + then gets marked as pending deletion again, d) during the deletion + transaction we recheck `state_groups_pending_deletion` table again and see + that it exists and so continue with the deletion. To prevent this from + happening we add a `sequence_number` column to + `state_groups_pending_deletion`, and during deletion we ensure that for a + state group we're about to delete that the sequence number doesn't change + between steps (a) and (d). So long as we always bump the sequence number + whenever an event may become used the race can never happen. + """ + + # How long to wait before we delete state groups. This should be long enough + # for any in-flight events to be persisted. If events take longer to persist + # and any of the state groups they reference have been deleted, then the + # event will fail to persist (as well as any event in the same batch). + DELAY_BEFORE_DELETION_MS = 10 * 60 * 1000 + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + self._clock = hs.get_clock() + self.db_pool = database + self._instance_name = hs.get_instance_name() + + with db_conn.cursor(txn_name="_clear_existing_persising") as txn: + self._clear_existing_persising(txn) + + def _clear_existing_persising(self, txn: LoggingTransaction) -> None: + """On startup we clear any entries in `state_groups_persisting` that + match our instance name, in case of a previous unclean shutdown""" + + self.db_pool.simple_delete_txn( + txn, + table="state_groups_persisting", + keyvalues={"instance_name": self._instance_name}, + ) + + async def check_state_groups_and_bump_deletion( + self, state_groups: AbstractSet[int] + ) -> Collection[int]: + """Checks to make sure that the state groups haven't been deleted, and + if they're pending deletion we delay it (allowing time for any event + that will use them to finish persisting). + + Returns: + The state groups that are missing, if any. + """ + + return await self.db_pool.runInteraction( + "check_state_groups_and_bump_deletion", + self._check_state_groups_and_bump_deletion_txn, + state_groups, + # We don't need to lock if we're just doing a quick check, as the + # lock doesn't prevent any races here. + lock=False, + ) + + def _check_state_groups_and_bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: AbstractSet[int], lock: bool = True + ) -> Collection[int]: + """Checks to make sure that the state groups haven't been deleted, and + if they're pending deletion we delay it (allowing time for any event + that will use them to finish persisting). + + The `lock` flag sets if we should lock the `state_group` rows we're + checking, which we should do when storing new groups. + + Returns: + The state groups that are missing, if any. + """ + + existing_state_groups = self._get_existing_groups_with_lock( + txn, state_groups, lock=lock + ) + + self._bump_deletion_txn(txn, existing_state_groups) + + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + return missing_state_groups + + return () + + def _bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> None: + """Update any pending deletions of the state group that they may now be + referenced.""" + + if not state_groups: + return + + now = self._clock.time_msec() + if isinstance(self.db_pool.engine, PostgresEngine): + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "state_group", state_groups + ) + sql = f""" + UPDATE state_groups_pending_deletion + SET sequence_number = DEFAULT, insertion_ts = ? + WHERE {clause} + """ + args.insert(0, now) + txn.execute(sql, args) + else: + rows = self.db_pool.simple_select_many_txn( + txn, + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group",), + ) + if not rows: + return + + state_groups_to_update = [state_group for (state_group,) in rows] + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_pending_deletion", + column="state_group", + values=state_groups_to_update, + keyvalues={}, + ) + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group", "insertion_ts"), + values=[(state_group, now) for state_group in state_groups_to_update], + ) + + def _get_existing_groups_with_lock( + self, txn: LoggingTransaction, state_groups: Collection[int], lock: bool = True + ) -> AbstractSet[int]: + """Return which of the given state groups are in the database, and locks + those rows with `KEY SHARE` to ensure they don't get concurrently + deleted (if `lock` is true).""" + clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups) + + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + """ + if lock and isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we add a row level lock to the rows to ensure that we + # conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will + # not conflict with other read + sql += """ + FOR KEY SHARE + """ + + txn.execute(sql, args) + return {state_group for (state_group,) in txn} + + @contextlib.asynccontextmanager + async def persisting_state_group_references( + self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] + ) -> AsyncIterator[None]: + """Wraps the persistence of the given events and contexts, ensuring that + any state groups referenced still exist and that they don't get deleted + during this.""" + + referenced_state_groups: Set[int] = set() + for event, ctx in event_and_contexts: + if ctx.rejected or event.internal_metadata.is_outlier(): + continue + + assert ctx.state_group is not None + + referenced_state_groups.add(ctx.state_group) + + if ctx.state_group_before_event: + referenced_state_groups.add(ctx.state_group_before_event) + + if not referenced_state_groups: + # We don't reference any state groups, so nothing to do + yield + return + + await self.db_pool.runInteraction( + "mark_state_groups_as_persisting", + self._mark_state_groups_as_persisting_txn, + referenced_state_groups, + ) + + error = True + try: + yield None + error = False + finally: + await self.db_pool.runInteraction( + "finish_persisting", + self._finish_persisting_txn, + referenced_state_groups, + error=error, + ) + + def _mark_state_groups_as_persisting_txn( + self, txn: LoggingTransaction, state_groups: Set[int] + ) -> None: + """Marks the given state groups as being persisted.""" + + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + raise Exception( + f"state groups have been deleted: {shortstr(missing_state_groups)}" + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_persisting", + keys=("state_group", "instance_name"), + values=[(state_group, self._instance_name) for state_group in state_groups], + ) + + def _finish_persisting_txn( + self, txn: LoggingTransaction, state_groups: Collection[int], error: bool + ) -> None: + """Mark the state groups as having finished persistence. + + If `error` is true then we assume the state groups were not persisted, + and so we do not clear them from the pending deletion table. + """ + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_persisting", + column="state_group", + values=state_groups, + keyvalues={"instance_name": self._instance_name}, + ) + + if error: + # The state groups may or may not have been persisted, so we need to + # bump the deletion to ensure we recheck if they have become + # referenced. + self._bump_deletion_txn(txn, state_groups) + return + + self.db_pool.simple_delete_many_batch_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group",), + values=[(state_group,) for state_group in state_groups], + ) + + async def mark_state_groups_as_pending_deletion( + self, state_groups: Collection[int] + ) -> None: + """Mark the given state groups as pending deletion. + + If any of the state groups are already pending deletion, then those records are + left as is. + """ + + await self.db_pool.runInteraction( + "mark_state_groups_as_pending_deletion", + self._mark_state_groups_as_pending_deletion_txn, + state_groups, + ) + + def _mark_state_groups_as_pending_deletion_txn( + self, + txn: LoggingTransaction, + state_groups: Collection[int], + ) -> None: + sql = """ + INSERT INTO state_groups_pending_deletion (state_group, insertion_ts) + VALUES %s + ON CONFLICT (state_group) + DO NOTHING + """ + + now = self._clock.time_msec() + rows = [ + ( + state_group, + now, + ) + for state_group in state_groups + ] + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("(?, ?)",), rows) + + async def mark_state_groups_as_used(self, state_groups: Collection[int]) -> None: + """Mark the given state groups as now being referenced""" + + await self.db_pool.simple_delete_many( + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + keyvalues={}, + desc="mark_state_groups_as_used", + ) + + async def get_pending_deletions( + self, state_groups: Collection[int] + ) -> Mapping[int, int]: + """Get which state groups are pending deletion. + + Returns: + a mapping from state groups that are pending deletion to their + sequence number + """ + + rows = await self.db_pool.simple_select_many_batch( + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + retcols=("state_group", "sequence_number"), + keyvalues={}, + desc="get_pending_deletions", + ) + + return dict(rows) + + def get_state_groups_ready_for_potential_deletion_txn( + self, + txn: LoggingTransaction, + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> Collection[int]: + """Given a set of state groups, return which state groups can + potentially be deleted. + + The state groups must have been checked to see if they remain + unreferenced before calling this function. + + Note: This must be called within the same transaction that the state + groups are deleted. + + Args: + state_groups_to_sequence_numbers: The state groups, and the sequence + numbers from before the state groups were checked to see if they + were unreferenced. + + Returns: + The subset of state groups that can safely be deleted + + """ + + if not state_groups_to_sequence_numbers: + return state_groups_to_sequence_numbers + + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we want to lock the rows FOR UPDATE as early as + # possible to help conflicts. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "id", state_groups_to_sequence_numbers + ) + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + FOR UPDATE + """ + txn.execute(sql, args) + + # Check the deletion status in the DB of the given state groups + clause, args = make_in_list_sql_clause( + self.db_pool.engine, + column="state_group", + iterable=state_groups_to_sequence_numbers, + ) + + sql = f""" + SELECT state_group, insertion_ts, sequence_number FROM ( + SELECT state_group, insertion_ts, sequence_number FROM state_groups_pending_deletion + UNION + SELECT state_group, null, null FROM state_groups_persisting + ) AS s + WHERE {clause} + """ + + txn.execute(sql, args) + + # The above query will return potentially two rows per state group (one + # for each table), so we track which state groups have enough time + # elapsed and which are not ready to be persisted. + ready_to_be_deleted = set() + not_ready_to_be_deleted = set() + + now = self._clock.time_msec() + for state_group, insertion_ts, sequence_number in txn: + if insertion_ts is None: + # A null insertion_ts means that we are currently persisting + # events that reference the state group, so we don't delete + # them. + not_ready_to_be_deleted.add(state_group) + continue + + # We know this can't be None if insertion_ts is not None + assert sequence_number is not None + + # Check if the sequence number has changed, if it has then it + # indicates that the state group may have become referenced since we + # checked. + if state_groups_to_sequence_numbers[state_group] != sequence_number: + not_ready_to_be_deleted.add(state_group) + continue + + if now - insertion_ts < self.DELAY_BEFORE_DELETION_MS: + # Not enough time has elapsed to allow us to delete. + not_ready_to_be_deleted.add(state_group) + continue + + ready_to_be_deleted.add(state_group) + + can_be_deleted = ready_to_be_deleted - not_ready_to_be_deleted + if not_ready_to_be_deleted: + # If there are any state groups that aren't ready to be deleted, + # then we also need to remove any state groups that are referenced + # by them. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, + column="state_group", + iterable=state_groups_to_sequence_numbers, + ) + sql = f""" + WITH RECURSIVE ancestors(state_group) AS ( + SELECT DISTINCT prev_state_group + FROM state_group_edges WHERE {clause} + UNION + SELECT prev_state_group + FROM state_group_edges + INNER JOIN ancestors USING (state_group) + ) + SELECT state_group FROM ancestors + """ + txn.execute(sql, args) + + can_be_deleted.difference_update(state_group for (state_group,) in txn) + + return can_be_deleted + + async def get_next_state_group_collection_to_delete( + self, + ) -> Optional[Tuple[str, Mapping[int, int]]]: + """Get the next set of state groups to try and delete + + Returns: + 2-tuple of room_id and mapping of state groups to sequence number. + """ + return await self.db_pool.runInteraction( + "get_next_state_group_collection_to_delete", + self._get_next_state_group_collection_to_delete_txn, + ) + + def _get_next_state_group_collection_to_delete_txn( + self, + txn: LoggingTransaction, + ) -> Optional[Tuple[str, Mapping[int, int]]]: + """Implementation of `get_next_state_group_collection_to_delete`""" + + # We want to return chunks of state groups that were marked for deletion + # at the same time (this isn't necessary, just more efficient). We do + # this by looking for the oldest insertion_ts, and then pulling out all + # rows that have the same insertion_ts (and room ID). + now = self._clock.time_msec() + + sql = """ + SELECT room_id, insertion_ts + FROM state_groups_pending_deletion AS sd + INNER JOIN state_groups AS sg ON (id = sd.state_group) + LEFT JOIN state_groups_persisting AS sp USING (state_group) + WHERE insertion_ts < ? AND sp.state_group IS NULL + ORDER BY insertion_ts + LIMIT 1 + """ + txn.execute(sql, (now - self.DELAY_BEFORE_DELETION_MS,)) + row = txn.fetchone() + if not row: + return None + + (room_id, insertion_ts) = row + + sql = """ + SELECT state_group, sequence_number + FROM state_groups_pending_deletion AS sd + INNER JOIN state_groups AS sg ON (id = sd.state_group) + LEFT JOIN state_groups_persisting AS sp USING (state_group) + WHERE room_id = ? AND insertion_ts = ? AND sp.state_group IS NULL + ORDER BY insertion_ts + """ + txn.execute(sql, (room_id, insertion_ts)) + + return room_id, dict(txn) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index d4ac74c1ee..c1a66dcba0 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -22,10 +22,10 @@ import logging from typing import ( TYPE_CHECKING, - Collection, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -36,7 +36,10 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase +from synapse.events.snapshot import ( + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -45,6 +48,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap @@ -55,6 +59,7 @@ from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.state.deletion import StateDeletionDataStore logger = logging.getLogger(__name__) @@ -83,8 +88,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", + state_deletion_store: "StateDeletionDataStore", ): super().__init__(database, db_conn, hs) + self._state_deletion_store = state_deletion_store # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -284,7 +291,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Returns: Dict of state group to state map. """ - state_filter = state_filter or StateFilter.all() + if state_filter is None: + state_filter = StateFilter.all() member_filter, non_member_filter = state_filter.get_member_split() @@ -466,14 +474,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Returns: A list of state groups """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + + # We need to check that the prev group isn't about to be deleted + is_missing = ( + self._state_deletion_store._check_state_groups_and_bump_deletion_txn( + txn, + {prev_group}, + ) ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -545,6 +554,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): for key, state_id in context.state_delta_due_to_event.items() ], ) + return events_and_context return await self.db_pool.runInteraction( @@ -600,14 +610,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): The state group if successfully created, or None if the state needs to be persisted as a full state. """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + + # We need to check that the prev group isn't about to be deleted + is_missing = ( + self._state_deletion_store._check_state_groups_and_bump_deletion_txn( + txn, + {prev_group}, + ) ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -725,8 +736,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete: Collection[int] - ) -> None: + self, + room_id: str, + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> bool: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -734,21 +747,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): room_id: The room the state groups belong to (must all be in the same room). state_groups_to_delete: Set of all state groups to delete. + + Returns: + Whether any state groups were actually deleted. """ - await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, - state_groups_to_delete, + state_groups_to_sequence_numbers, ) def _purge_unreferenced_state_groups( self, txn: LoggingTransaction, room_id: str, - state_groups_to_delete: Collection[int], - ) -> None: + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> bool: + state_groups_to_delete = self._state_deletion_store.get_state_groups_ready_for_potential_deletion_txn( + txn, state_groups_to_sequence_numbers + ) + + if not state_groups_to_delete: + return False + logger.info( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) @@ -767,7 +790,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): remaining_state_groups = { state_group - for state_group, in rows + for (state_group,) in rows if state_group not in state_groups_to_delete } @@ -804,13 +827,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): logger.info("[purge] removing redundant state groups") txn.execute_batch( "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), + [(sg,) for sg in state_groups_to_delete], + ) + txn.execute_batch( + "DELETE FROM state_group_edges WHERE state_group = ?", + [(sg,) for sg in state_groups_to_delete], ) txn.execute_batch( "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), + [(sg,) for sg in state_groups_to_delete], + ) + txn.execute_batch( + "DELETE FROM state_groups_pending_deletion WHERE state_group = ?", + [(sg,) for sg in state_groups_to_delete], ) + return True + @trace @tag_args async def get_previous_state_groups( @@ -829,7 +862,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): List[Tuple[int, int]], await self.db_pool.simple_select_many_batch( table="state_group_edges", - column="prev_state_group", + column="state_group", iterable=state_groups, keyvalues={}, retcols=("state_group", "prev_state_group"), @@ -839,60 +872,77 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return dict(rows) - async def purge_room_state( - self, room_id: str, state_groups_to_delete: Collection[int] - ) -> None: - """Deletes all record of a room from state tables + @trace + @tag_args + async def get_next_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: + """Fetch the groups that have the given state groups as their previous + state groups. Args: - room_id: - state_groups_to_delete: State groups to delete + state_groups + + Returns: + A mapping from state group to previous state group. """ - logger.info("[purge] Starting state purge") - await self.db_pool.runInteraction( + rows = cast( + List[Tuple[int, int]], + await self.db_pool.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group", "prev_state_group"), + desc="get_next_state_groups", + ), + ) + + return dict(rows) + + async def purge_room_state(self, room_id: str) -> None: + return await self.db_pool.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, - state_groups_to_delete, ) - logger.info("[purge] Done with state purge") def _purge_room_state_txn( self, txn: LoggingTransaction, room_id: str, - state_groups_to_delete: Collection[int], ) -> None: - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) + # Delete all edges that reference a state group linked to room_id + logger.info("[purge] removing %s from state_group_edges", room_id) - self.db_pool.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - values=state_groups_to_delete, - keyvalues={}, - ) + if isinstance(self.database_engine, PostgresEngine): + # Disable statement timeouts for this transaction; purging rooms can + # take a while! + txn.execute("SET LOCAL statement_timeout = 0") - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) + txn.execute( + """ + DELETE FROM state_group_edges AS sge WHERE sge.state_group IN ( + SELECT id FROM state_groups AS sg WHERE sg.room_id = ? + )""", + (room_id,), + ) - self.db_pool.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - values=state_groups_to_delete, - keyvalues={}, + # state_groups_state table has a room_id column but no index on it, unlike state_groups, + # so we delete them by matching the room_id through the state_groups table. + logger.info("[purge] removing %s from state_groups_state", room_id) + txn.execute( + """ + DELETE FROM state_groups_state AS sgs WHERE sgs.state_group IN ( + SELECT id FROM state_groups AS sg WHERE sg.room_id = ? + )""", + (room_id,), ) - # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - - self.db_pool.simple_delete_many_txn( + self.db_pool.simple_delete_txn( txn, table="state_groups", - column="id", - values=state_groups_to_delete, - keyvalues={}, + keyvalues={"room_id": room_id}, )