summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py38
-rw-r--r--synapse/storage/databases/main/account_data.py3
-rw-r--r--synapse/storage/databases/main/appservice.py58
-rw-r--r--synapse/storage/databases/main/cache.py33
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py13
-rw-r--r--synapse/storage/databases/main/devices.py125
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py53
-rw-r--r--synapse/storage/databases/main/event_federation.py9
-rw-r--r--synapse/storage/databases/main/event_push_actions.py947
-rw-r--r--synapse/storage/databases/main/events.py151
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py208
-rw-r--r--synapse/storage/databases/main/events_worker.py306
-rw-r--r--synapse/storage/databases/main/group_server.py34
-rw-r--r--synapse/storage/databases/main/media_repository.py10
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py1
-rw-r--r--synapse/storage/databases/main/purge_events.py37
-rw-r--r--synapse/storage/databases/main/push_rule.py130
-rw-r--r--synapse/storage/databases/main/receipts.py92
-rw-r--r--synapse/storage/databases/main/registration.py13
-rw-r--r--synapse/storage/databases/main/relations.py6
-rw-r--r--synapse/storage/databases/main/room.py175
-rw-r--r--synapse/storage/databases/main/roommember.py502
-rw-r--r--synapse/storage/databases/main/search.py10
-rw-r--r--synapse/storage/databases/main/state.py14
-rw-r--r--synapse/storage/databases/main/stats.py16
-rw-r--r--synapse/storage/databases/main/stream.py54
-rw-r--r--synapse/storage/databases/state/bg_updates.py3
-rw-r--r--synapse/storage/databases/state/store.py165
29 files changed, 2193 insertions, 1015 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py

index 11d9d16c19..4dccbb732a 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.stats import UserSortOrder -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -45,7 +45,6 @@ from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore -from .group_server import GroupServerStore from .keys import KeyStore from .lock import LockStore from .media_repository import MediaRepositoryStore @@ -88,7 +87,6 @@ class DataStore( RoomStore, RoomBatchStore, RegistrationStore, - StreamWorkerStore, ProfileStore, PresenceStore, TransactionWorkerStore, @@ -105,19 +103,20 @@ class DataStore( PusherStore, PushRuleStore, ApplicationServiceTransactionStore, + EventPushActionsStore, + ServerMetricsStore, ReceiptsStore, EndToEndKeyStore, EndToEndRoomKeyStore, SearchStore, TagsStore, AccountDataStore, - EventPushActionsStore, + StreamWorkerStore, OpenIdStore, ClientIpWorkerStore, DeviceStore, DeviceInboxStore, UserDirectoryStore, - GroupServerStore, UserErasureStore, MonthlyActiveUsersWorkerStore, StatsStore, @@ -126,7 +125,6 @@ class DataStore( UIAuthStore, EventForwardExtremitiesStore, CacheInvalidationWorkerStore, - ServerMetricsStore, LockStore, SessionStore, ): @@ -151,31 +149,6 @@ class DataStore( ], ) - self._cache_id_gen: Optional[MultiWriterIdGenerator] - if isinstance(self.database_engine, PostgresEngine): - # We set the `writers` to an empty list here as we don't care about - # missing updates over restarts, as we'll not have anything in our - # caches to invalidate. (This reduces the amount of writes to the DB - # that happen). - self._cache_id_gen = MultiWriterIdGenerator( - db_conn, - database, - stream_name="caches", - instance_name=hs.get_instance_name(), - tables=[ - ( - "cache_invalidation_stream_by_instance", - "instance_name", - "stream_id", - ) - ], - sequence_name="cache_invalidation_stream_seq", - writers=[], - ) - - else: - self._cache_id_gen = None - super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() @@ -197,6 +170,7 @@ class DataStore( self._min_stream_order_on_start = self.get_room_min_stream_ordering() def get_device_stream_token(self) -> int: + # TODO: shouldn't this be moved to `DeviceWorkerStore`? return self._device_list_id_gen.get_current_token() async def get_users(self) -> List[JsonDict]: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn, self.get_account_data_for_room, (user_id,) ) self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,)) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_enabled_for_user, (user_id,) - ) # This user might be contained in the ignored_by cache for other users, # so we have to invalidate it all. self._invalidate_all_cache_and_stream(txn, self.ignored_by) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore( device_list_summary=DeviceListUpdates(), ) - async def set_appservice_last_pos(self, pos: int) -> None: - def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None: - txn.execute( - "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) - ) + async def get_appservice_last_pos(self) -> int: + """ + Get the last stream ordering position for the appservice process. + """ - await self.db_pool.runInteraction( - "set_appservice_last_pos", set_appservice_last_pos_txn + return await self.db_pool.simple_select_one_onecol( + table="appservice_stream_position", + retcol="stream_ordering", + keyvalues={}, + desc="get_appservice_last_pos", ) - async def get_new_events_for_appservice( - self, current_id: int, limit: int - ) -> Tuple[int, List[EventBase]]: - """Get all new events for an appservice""" - - def get_new_events_for_appservice_txn( - txn: LoggingTransaction, - ) -> Tuple[int, List[str]]: - sql = ( - "SELECT e.stream_ordering, e.event_id" - " FROM events AS e" - " WHERE" - " (SELECT stream_ordering FROM appservice_stream_position)" - " < e.stream_ordering" - " AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" - " LIMIT ?" - ) - - txn.execute(sql, (current_id, limit)) - rows = txn.fetchall() - - upper_bound = current_id - if len(rows) == limit: - upper_bound = rows[-1][0] - - return upper_bound, [row[1] for row in rows] + async def set_appservice_last_pos(self, pos: int) -> None: + """ + Set the last stream ordering position for the appservice process. + """ - upper_bound, event_ids = await self.db_pool.runInteraction( - "get_new_events_for_appservice", get_new_events_for_appservice_txn + await self.db_pool.simple_update_one( + table="appservice_stream_position", + keyvalues={}, + updatevalues={"stream_ordering": pos}, + desc="set_appservice_last_pos", ) - events = await self.get_events_as_list(event_ids, get_prev_content=True) - - return upper_bound, events - async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..12e9a42382 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter @@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore): psql_only=True, # The table is only on postgres DBs. ) + self._cache_id_gen: Optional[MultiWriterIdGenerator] + if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + stream_name="caches", + instance_name=hs.get_instance_name(), + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], + sequence_name="cache_invalidation_stream_seq", + writers=[], + ) + + else: + self._cache_id_gen = None + async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: @@ -193,7 +219,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): relates_to: Optional[str], backfilled: bool, ) -> None: - self._invalidate_get_event_cache(event_id) + # This invalidates any local in-memory cached event objects, the original + # process triggering the invalidation is responsible for clearing any external + # cached objects. + self._invalidate_local_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -208,7 +237,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: - self._invalidate_get_event_cache(redacts) + self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. self.get_relations_for_event.invalidate((redacts,)) diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # changed its content in the database. We can't call # self._invalidate_cache_and_stream because self.get_event_cache isn't of the # right type. - txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + self.invalidate_get_event_cache_after_txn(txn, event.event_id) # Send that invalidation to replication so that other workers also invalidate # the event cache. self._send_invalidation_to_replication( diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 599b418383..73c95ffb6f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): (user_id, device_id), None ) - set_tag("last_deleted_stream_id", last_deleted_stream_id) + set_tag("last_deleted_stream_id", str(last_deleted_stream_id)) if last_deleted_stream_id: has_changed = self._device_inbox_stream_cache.has_entity_changed( @@ -834,8 +834,6 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox" - REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox" def __init__( @@ -857,15 +855,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) - # Used to be a background update that deletes all device_inboxes for deleted - # devices. - self.db_pool.updates.register_noop_background_update( - self.REMOVE_DELETED_DEVICES - ) - # Used to be a background update that deletes all device_inboxes for hidden - # devices. - self.db_pool.updates.register_noop_background_update(self.REMOVE_HIDDEN_DEVICES) - self.db_pool.updates.register_background_update_handler( self.REMOVE_DEAD_DEVICES_FROM_INBOX, self._remove_dead_devices_from_device_inbox, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 71e7863dd8..ca0fe8c4be 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,8 @@ from typing import ( cast, ) +from typing_extensions import Literal + from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -44,6 +46,8 @@ from synapse.storage.database import ( LoggingTransaction, make_tuple_comparison_clause, ) +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.types import Cursor from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" -class DeviceWorkerStore(SQLBaseStore): +class DeviceWorkerStore(EndToEndKeyWorkerStore): def __init__( self, database: DatabasePool, @@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - device_list_max = self._device_list_id_gen.get_current_token() + # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a + # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker). + device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined] device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( db_conn, "device_lists_stream", @@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore): # following this stream later. last_processed_stream_id = from_stream_id - query_map = {} - cross_signing_keys_by_user = {} + # A map of (user ID, device ID) to (stream ID, context). + query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {} + cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {} for user_id, device_id, update_stream_id, update_context in updates: # Calculate the remaining length budget. # Note that, for now, each entry in `cross_signing_keys_by_user` @@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore): txn=txn, table="device_lists_outbound_last_success", key_names=("destination", "user_id"), - key_values=((destination, user_id) for user_id, _ in rows), + key_values=[(destination, user_id) for user_id, _ in rows], value_names=("stream_id",), value_values=((stream_id,) for _, stream_id in rows), ) @@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore): The new stream ID. """ - async with self._device_list_id_gen.get_next() as stream_id: + # TODO: this looks like it's _writing_. Should this be on DeviceStore rather + # than DeviceWorkerStore? + async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined] await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -660,7 +669,7 @@ class DeviceWorkerStore(SQLBaseStore): @trace async def get_user_devices_from_cache( - self, query_list: List[Tuple[str, str]] + self, query_list: List[Tuple[str, Optional[str]]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. @@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore): } - users_needing_resync user_ids_not_in_cache = user_ids - user_ids_in_cache - results = {} + results: Dict[str, Dict[str, JsonDict]] = {} for user_id, device_id in query_list: if user_id not in user_ids_in_cache: continue @@ -697,8 +706,8 @@ class DeviceWorkerStore(SQLBaseStore): else: results[user_id] = await self.get_cached_devices_for_user(user_id) - set_tag("in_cache", results) - set_tag("not_in_cache", user_ids_not_in_cache) + set_tag("in_cache", str(results)) + set_tag("not_in_cache", str(user_ids_not_in_cache)) return user_ids_not_in_cache, results @@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore): def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[Set[str]]: + ) -> Optional[List[str]]: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ @@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore): async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Iterable[str]] = None, + user_ids: Optional[Collection[str]] = None, to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore): """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. + user_ids_to_check: Optional[Collection[str]] if user_ids is None: # Get set of all users that have had device list changes since 'from_key' user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( @@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore): return set() def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - changes = set() + changes: Set[str] = set() stream_id_where_clause = "stream_id > ?" sql_args = [from_key] @@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore): """ # Query device changes with a batch of users at a time + # Assertion for mypy's benefit; see also + # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions + assert user_ids_to_check is not None for chunk in batch_iter(user_ids_to_check, 100): clause, args = make_in_list_sql_clause( txn.database_engine, "user_id", chunk @@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def _get_all_device_list_changes_for_remotes(txn): + def _get_all_device_list_changes_for_remotes( + txn: Cursor, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # This query Does The Right Thing where it'll correctly apply the # bounds to the inner queries. sql = """ @@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="get_device_list_last_stream_id_for_remotes", ) - results = {user_id: None for user_id in user_ids} + results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids} results.update({row["user_id"]: row["stream_id"] for row in rows}) return results @@ -1193,6 +1208,65 @@ class DeviceWorkerStore(SQLBaseStore): return devices + @cached() + async def _get_min_device_lists_changes_in_room(self) -> int: + """Returns the minimum stream ID that we have entries for + `device_lists_changes_in_room` + """ + + return await self.db_pool.simple_select_one_onecol( + table="device_lists_changes_in_room", + keyvalues={}, + retcol="COALESCE(MIN(stream_id), 0)", + desc="get_min_device_lists_changes_in_room", + ) + + async def get_device_list_changes_in_rooms( + self, room_ids: Collection[str], from_id: int + ) -> Optional[Set[str]]: + """Return the set of users whose devices have changed in the given rooms + since the given stream ID. + + Returns None if the given stream ID is too old. + """ + + if not room_ids: + return set() + + min_stream_id = await self._get_min_device_lists_changes_in_room() + + if min_stream_id > from_id: + return None + + sql = """ + SELECT DISTINCT user_id FROM device_lists_changes_in_room + WHERE {clause} AND stream_id >= ? + """ + + def _get_device_list_changes_in_rooms_txn( + txn: LoggingTransaction, + clause: str, + args: List[Any], + ) -> Set[str]: + txn.execute(sql.format(clause=clause), args) + return {user_id for user_id, in txn} + + changes = set() + for chunk in batch_iter(room_ids, 1000): + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", chunk + ) + args.append(from_id) + + changes |= await self.db_pool.runInteraction( + "get_device_list_changes_in_rooms", + _get_device_list_changes_in_rooms_txn, + clause, + args, + ) + + return changes + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( @@ -1240,15 +1314,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): self._remove_duplicate_outbound_pokes, ) - # a pair of background updates that were added during the 1.14 release cycle, - # but replaced with 58/06dlols_unique_idx.py - self.db_pool.updates.register_noop_background_update( - "device_lists_outbound_last_success_unique_idx", - ) - self.db_pool.updates.register_noop_background_update( - "drop_device_lists_outbound_last_success_non_unique_idx", - ) - async def _drop_device_list_streams_non_unique_indexes( self, progress: JsonDict, batch_size: int ) -> int: @@ -1346,9 +1411,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = LruCache( - cache_name="device_id_exists", max_size=10000 - ) + self.device_id_exists_cache: LruCache[ + Tuple[str, str], Literal[True] + ] = LruCache(cache_name="device_id_exists", max_size=10000) async def store_device( self, @@ -1660,7 +1725,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context, ) - async with self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined] len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1713,7 +1778,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_ids: Iterable[str], hosts: Collection[str], stream_ids: List[int], - context: Dict[str, str], + context: Optional[Dict[str, str]], ) -> None: for host in hosts: txn.call_after( @@ -1884,7 +1949,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [], ) - async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined] return await self.db_pool.runInteraction( "add_device_list_outbound_pokes", add_device_list_outbound_pokes_txn, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..46c0d06157 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import ( List, Optional, Tuple, + Union, cast, + overload, ) import attr from canonicaljson import encode_canonical_json +from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( @@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker user_devices = devices[user_id] results = [] for device_id, device in user_devices.items(): - result = {"device_id": device_id} + result: JsonDict = {"device_id": device_id} keys = device.keys if keys: @@ -143,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker key data. The key data will be a dict in the same format as the DeviceKeys type returned by POST /_matrix/client/r0/keys/query. """ - set_tag("query_list", query_list) + set_tag("query_list", str(query_list)) if not query_list: return {} @@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker rv[user_id] = {} for device_id, device_info in device_keys.items(): r = device_info.keys + if r is None: + continue + r["unsigned"] = {} display_name = device_info.display_name if display_name is not None: @@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return rv + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[True], + include_deleted_devices: Literal[True], + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ... + @trace async def get_e2e_device_keys_and_signatures( self, - query_list: List[Tuple[str, Optional[str]]], + query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, - ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ) -> Union[ + Dict[str, Dict[str, DeviceKeyLookupResult]], + Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]], + ]: """Fetch a list of device keys Any cross-signatures made on the keys by the owner of the device are also @@ -383,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: set_tag("user_id", user_id) set_tag("device_id", device_id) - set_tag("new_keys", new_keys) + set_tag("new_keys", str(new_keys)) # We are protected from race between lookup and insertion due to # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only @@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - row = await self.db_pool.runInteraction( + claim_row = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, @@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker algorithm, db_autocommit=db_autocommit, ) - if row: + if claim_row: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[row[0]] = row[1] + device_results[claim_row[0]] = claim_row[1] continue # No one-time key available, so see if there's a fallback @@ -1126,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("user_id", user_id) set_tag("device_id", device_id) set_tag("time_now", time_now) - set_tag("device_keys", device_keys) + set_tag("device_keys", str(device_keys)) old_key_json = self.db_pool.simple_select_one_onecol_txn( txn, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict +from synapse.logging.opentracing import tag_args, trace from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) return await self.get_events_as_list(event_ids) + @trace + @tag_args async def get_auth_chain_ids( self, room_id: str, @@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} + @trace + @tag_args async def get_oldest_event_ids_with_depth_in_room( self, room_id: str ) -> List[Tuple[str, int]]: @@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) + @trace async def get_insertion_event_backward_extremities_in_room( self, room_id: str ) -> List[Tuple[str, int]]: @@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_results.reverse() return event_results + @trace + @tag_args async def get_successor_events(self, event_id: str) -> List[str]: """Fetch all events that have the given event as a prev event @@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + @trace async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: await self.db_pool.simple_upsert( table="insertion_event_extremities", diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index b019979350..f4a07de2a3 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,18 +12,92 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Responsible for storing and fetching push actions / notifications. + +There are two main uses for push actions: + 1. Sending out push to a user's device; and + 2. Tracking per-room per-user notification counts (used in sync requests). + +For the former we simply use the `event_push_actions` table, which contains all +the calculated actions for a given user (which were calculated by the +`BulkPushRuleEvaluator`). + +For the latter we could simply count the number of rows in `event_push_actions` +table for a given room/user, but in practice this is *very* heavyweight when +there were a large number of notifications (due to e.g. the user never reading a +room). Plus, keeping all push actions indefinitely uses a lot of disk space. + +To fix these issues, we add a new table `event_push_summary` that tracks +per-user per-room counts of all notifications that happened before a stream +ordering S. Thus, to get the notification count for a user / room we can simply +query a single row in `event_push_summary` and count the number of rows in +`event_push_actions` with a stream ordering larger than S (and as long as S is +"recent", the number of rows needing to be scanned will be small). + +The `event_push_summary` table is updated via a background job that periodically +chooses a new stream ordering S' (usually the latest stream ordering), counts +all notifications in `event_push_actions` between the existing S and S', and +adds them to the existing counts in `event_push_summary`. + +This allows us to delete old rows from `event_push_actions` once those rows have +been counted and added to `event_push_summary` (we call this process +"rotation"). + + +We need to handle when a user sends a read receipt to the room. Again this is +done as a background process. For each receipt we clear the row in +`event_push_summary` and count the number of notifications in +`event_push_actions` that happened after the receipt but before S, and insert +that count into `event_push_summary` (If the receipt happened *after* S then we +simply clear the `event_push_summary`.) + +Note that its possible that if the read receipt is for an old event the relevant +`event_push_actions` rows will have been rotated and we get the wrong count +(it'll be too low). We accept this as a rare edge case that is unlikely to +impact the user much (since the vast majority of read receipts will be for the +latest event). + +The last complication is to handle the race where we request the notifications +counts after a user sends a read receipt into the room, but *before* the +background update handles the receipt (without any special handling the counts +would be outdated). We fix this by including in `event_push_summary` the read +receipt we used when updating `event_push_summary`, and every time we query the +table we check if that matches the most recent read receipt in the room. If yes, +continue as above, if not we simply query the `event_push_actions` table +directly. + +Since read receipts are almost always for recent events, scanning the +`event_push_actions` table in this case is unlikely to be a problem. Even if it +is a problem, it is temporary until the background job handles the new read +receipt. +""" + import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, + cast, +) import attr +from synapse.api.constants import ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -79,18 +153,20 @@ class UserPushAction(EmailPushAction): profile_tag: str -@attr.s(slots=True, frozen=True, auto_attribs=True) +@attr.s(slots=True, auto_attribs=True) class NotifCounts: """ The per-user, per-room count of notifications. Used by sync and push. """ - notify_count: int - unread_count: int - highlight_count: int + notify_count: int = 0 + unread_count: int = 0 + highlight_count: int = 0 -def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: +def _serialize_action( + actions: Collection[Union[Mapping, str]], is_highlight: bool +) -> str: """Custom serializer for actions. This allows us to "compress" common actions. We use the fact that most users have the same actions for notifs (and for @@ -119,7 +195,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st return DEFAULT_NOTIF_ACTION -class EventPushActionsWorkerStore(SQLBaseStore): +class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore): def __init__( self, database: DatabasePool, @@ -140,23 +216,30 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._find_stream_orderings_for_times, 10 * 60 * 1000 ) - self._rotate_delay = 3 self._rotate_count = 10000 self._doing_notif_rotation = False if hs.config.worker.run_background_tasks: self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 60 * 1000 + self._rotate_notifs, 30 * 1000 ) - @cached(num_args=3, tree=True, max_entries=5000) + self.db_pool.updates.register_background_index_update( + "event_push_summary_unique_index", + index_name="event_push_summary_unique_index", + table="event_push_summary", + columns=["user_id", "room_id"], + unique=True, + replaces_index="event_push_summary_user_rm", + ) + + @cached(tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, user_id: str, - last_read_event_id: Optional[str], ) -> NotifCounts: """Get the notification count, the highlight count and the unread message count - for a given user in a given room after the given read receipt. + for a given user in a given room after their latest read receipt. Note that this function assumes the user to be a current member of the room, since it's either called by the sync handler to handle joined room entries, or by @@ -165,20 +248,16 @@ class EventPushActionsWorkerStore(SQLBaseStore): Args: room_id: The room to retrieve the counts in. user_id: The user to retrieve the counts for. - last_read_event_id: The event associated with the latest read receipt for - this user in this room. None if no receipt for this user in this room. Returns - A dict containing the counts mentioned earlier in this docstring, - respectively under the keys "notify_count", "highlight_count" and - "unread_count". + A NotifCounts object containing the notification count, the highlight count + and the unread message count. """ return await self.db_pool.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, user_id, - last_read_event_id, ) def _get_unread_counts_by_receipt_txn( @@ -186,20 +265,23 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn: LoggingTransaction, room_id: str, user_id: str, - last_read_event_id: Optional[str], ) -> NotifCounts: - stream_ordering = None + # Get the stream ordering of the user's latest receipt in the room. + result = self.get_last_receipt_for_user_txn( + txn, + user_id, + room_id, + receipt_types=( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ), + ) - if last_read_event_id is not None: - stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined] - txn, - last_read_event_id, - allow_none=True, - ) + if result: + _, stream_ordering = result - if stream_ordering is None: - # Either last_read_event_id is None, or it's an event we don't have (e.g. - # because it's been purged), in which case retrieve the stream ordering for + else: + # If the user has no receipts in the room, retrieve the stream ordering for # the latest membership event from this user in this room (which we assume is # a join). event_id = self.db_pool.simple_select_one_onecol_txn( @@ -209,57 +291,159 @@ class EventPushActionsWorkerStore(SQLBaseStore): retcol="event_id", ) - stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined] + stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) return self._get_unread_counts_by_pos_txn( txn, room_id, user_id, stream_ordering ) def _get_unread_counts_by_pos_txn( - self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int + self, + txn: LoggingTransaction, + room_id: str, + user_id: str, + receipt_stream_ordering: int, ) -> NotifCounts: - sql = ( - "SELECT" - " COUNT(CASE WHEN notif = 1 THEN 1 END)," - " COUNT(CASE WHEN highlight = 1 THEN 1 END)," - " COUNT(CASE WHEN unread = 1 THEN 1 END)" - " FROM event_push_actions ea" - " WHERE user_id = ?" - " AND room_id = ?" - " AND stream_ordering > ?" - ) + """Get the number of unread messages for a user/room that have happened + since the given stream ordering. - txn.execute(sql, (user_id, room_id, stream_ordering)) - row = txn.fetchone() + Args: + txn: The database transaction. + room_id: The room ID to get unread counts for. + user_id: The user ID to get unread counts for. + receipt_stream_ordering: The stream ordering of the user's latest + receipt in the room. If there are no receipts, the stream ordering + of the user's join event. - (notif_count, highlight_count, unread_count) = (0, 0, 0) + Returns + A NotifCounts object containing the notification count, the highlight count + and the unread message count. + """ - if row: - (notif_count, highlight_count, unread_count) = row + counts = NotifCounts() + # First we pull the counts from the summary table. + # + # We check that `last_receipt_stream_ordering` matches the stream + # ordering given. If it doesn't match then a new read receipt has arrived and + # we haven't yet updated the counts in `event_push_summary` to reflect + # that; in that case we simply ignore `event_push_summary` counts + # and do a manual count of all of the rows in the `event_push_actions` table + # for this user/room. + # + # If `last_receipt_stream_ordering` is null then that means it's up to + # date (as the row was written by an older version of Synapse that + # updated `event_push_summary` synchronously when persisting a new read + # receipt). txn.execute( """ - SELECT notif_count, unread_count FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + SELECT stream_ordering, notif_count, COALESCE(unread_count, 0) + FROM event_push_summary + WHERE room_id = ? AND user_id = ? + AND ( + (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) + OR last_receipt_stream_ordering = ? + ) """, - (room_id, user_id, stream_ordering), + (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), ) row = txn.fetchone() + summary_stream_ordering = 0 + if row: + summary_stream_ordering = row[0] + counts.notify_count += row[1] + counts.unread_count += row[2] + + # Next we need to count highlights, which aren't summarised + sql = """ + SELECT COUNT(*) FROM event_push_actions + WHERE user_id = ? + AND room_id = ? + AND stream_ordering > ? + AND highlight = 1 + """ + txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) + row = txn.fetchone() if row: - notif_count += row[0] - - if row[1] is not None: - # The unread_count column of event_push_summary is NULLable, so we need - # to make sure we don't try increasing the unread counts if it's NULL - # for this row. - unread_count += row[1] - - return NotifCounts( - notify_count=notif_count, - unread_count=unread_count, - highlight_count=highlight_count, + counts.highlight_count += row[0] + + # Finally we need to count push actions that aren't included in the + # summary returned above. This might be due to recent events that haven't + # been summarised yet or the summary is out of date due to a recent read + # receipt. + start_unread_stream_ordering = max( + receipt_stream_ordering, summary_stream_ordering ) + notify_count, unread_count = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, start_unread_stream_ordering + ) + + counts.notify_count += notify_count + counts.unread_count += unread_count + + return counts + + def _get_notif_unread_count_for_user_room( + self, + txn: LoggingTransaction, + room_id: str, + user_id: str, + stream_ordering: int, + max_stream_ordering: Optional[int] = None, + ) -> Tuple[int, int]: + """Returns the notify and unread counts from `event_push_actions` for + the given user/room in the given range. + + Does not consult `event_push_summary` table, which may include push + actions that have been deleted from `event_push_actions` table. + + Args: + txn: The database transaction. + room_id: The room ID to get unread counts for. + user_id: The user ID to get unread counts for. + stream_ordering: The (exclusive) minimum stream ordering to consider. + max_stream_ordering: The (inclusive) maximum stream ordering to consider. + If this is not given, then no maximum is applied. + + Return: + A tuple of the notif count and unread count in the given range. + """ + + # If there have been no events in the room since the stream ordering, + # there can't be any push actions either. + if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): + return 0, 0 + + clause = "" + args = [user_id, room_id, stream_ordering] + if max_stream_ordering is not None: + clause = "AND ea.stream_ordering <= ?" + args.append(max_stream_ordering) + + # If the max stream ordering is less than the min stream ordering, + # then obviously there are zero push actions in that range. + if max_stream_ordering <= stream_ordering: + return 0, 0 + + sql = f""" + SELECT + COUNT(CASE WHEN notif = 1 THEN 1 END), + COUNT(CASE WHEN unread = 1 THEN 1 END) + FROM event_push_actions ea + WHERE user_id = ? + AND room_id = ? + AND ea.stream_ordering > ? + {clause} + """ + + txn.execute(sql, args) + row = txn.fetchone() + + if row: + return cast(Tuple[int, int], row) + + return 0, 0 async def get_push_action_users_in_range( self, min_stream_ordering: int, max_stream_ordering: int @@ -274,6 +458,31 @@ class EventPushActionsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("get_push_action_users_in_range", f) + def _get_receipts_by_room_txn( + self, txn: LoggingTransaction, user_id: str + ) -> List[Tuple[str, int]]: + receipt_types_clause, args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + ( + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ), + ) + + sql = f""" + SELECT room_id, MAX(stream_ordering) + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE {receipt_types_clause} + AND user_id = ? + GROUP BY room_id + """ + + args.extend((user_id,)) + txn.execute(sql, args) + return cast(List[Tuple[str, int]], txn.fetchall()) + async def get_unread_push_actions_for_user_in_range_for_http( self, user_id: str, @@ -296,81 +505,46 @@ class EventPushActionsWorkerStore(SQLBaseStore): The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ - # find rooms that have a read receipt in them and return the next - # push actions - def get_after_receipt( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool]]: - # find rooms that have a read receipt in them and return the next - # push actions - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) - after_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt + receipts_by_room = dict( + await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, + ), ) - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt( + def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool]]: - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight " - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering ASC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight + FROM event_push_actions AS ep + WHERE + ep.user_id = ? + AND ep.stream_ordering > ? + AND ep.stream_ordering <= ? + AND ep.notif = 1 + ORDER BY ep.stream_ordering ASC LIMIT ? + """ + txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall()) - no_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt + push_actions = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn ) notifs = [ HttpPushAction( - event_id=row[0], - room_id=row[1], - stream_ordering=row[2], - actions=_deserialize_action(row[3], row[4]), + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + actions=_deserialize_action(actions, highlight), ) - for row in after_read_receipt + no_read_receipt + for event_id, room_id, stream_ordering, actions, highlight in push_actions + # Only include push actions with a stream ordering after any receipt, or without any + # receipt present (invited to but never read rooms). + if stream_ordering > receipts_by_room.get(room_id, 0) ] # Now sort it so it's ordered correctly, since currently it will @@ -405,82 +579,50 @@ class EventPushActionsWorkerStore(SQLBaseStore): The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ - # find rooms that have a read receipt in them and return the most recent - # push actions - def get_after_receipt( - txn: LoggingTransaction, - ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM (" - " SELECT room_id," - " MAX(stream_ordering) as stream_ordering" - " FROM events" - " INNER JOIN receipts_linearized USING (room_id, event_id)" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - ") AS rl," - " event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id = rl.room_id" - " AND ep.stream_ordering > rl.stream_ordering" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) - return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) - after_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt + receipts_by_room = dict( + await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email_receipts", + self._get_receipts_by_room_txn, + user_id=user_id, + ), ) - # There are rooms with push actions in them but you don't have a read receipt in - # them e.g. rooms you've been invited to, so get push actions for rooms which do - # not have read receipts in them too. - def get_no_receipt( + def get_push_actions_txn( txn: LoggingTransaction, ) -> List[Tuple[str, str, int, str, bool, int]]: - sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " ep.highlight, e.received_ts" - " FROM event_push_actions AS ep" - " INNER JOIN events AS e USING (room_id, event_id)" - " WHERE" - " ep.room_id NOT IN (" - " SELECT room_id FROM receipts_linearized" - " WHERE receipt_type = 'm.read' AND user_id = ?" - " GROUP BY room_id" - " )" - " AND ep.user_id = ?" - " AND ep.stream_ordering > ?" - " AND ep.stream_ordering <= ?" - " AND ep.notif = 1" - " ORDER BY ep.stream_ordering DESC LIMIT ?" - ) - args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] - txn.execute(sql, args) + sql = """ + SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, + ep.highlight, e.received_ts + FROM event_push_actions AS ep + INNER JOIN events AS e USING (room_id, event_id) + WHERE + ep.user_id = ? + AND ep.stream_ordering > ? + AND ep.stream_ordering <= ? + AND ep.notif = 1 + ORDER BY ep.stream_ordering DESC LIMIT ? + """ + txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit)) return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall()) - no_read_receipt = await self.db_pool.runInteraction( - "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt + push_actions = await self.db_pool.runInteraction( + "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn ) # Make a list of dicts from the two sets of results. notifs = [ EmailPushAction( - event_id=row[0], - room_id=row[1], - stream_ordering=row[2], - actions=_deserialize_action(row[3], row[4]), - received_ts=row[5], + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + actions=_deserialize_action(actions, highlight), + received_ts=received_ts, ) - for row in after_read_receipt + no_read_receipt + for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions + # Only include push actions with a stream ordering after any receipt, or without any + # receipt present (invited to but never read rooms). + if stream_ordering > receipts_by_room.get(room_id, 0) ] # Now sort it so it's ordered correctly, since currently it will @@ -526,7 +668,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): async def add_push_actions_to_staging( self, event_id: str, - user_id_actions: Dict[str, List[Union[dict, str]]], + user_id_actions: Dict[str, Collection[Union[Mapping, str]]], count_as_unread: bool, ) -> None: """Add the push actions for the event to the push action staging area. @@ -543,7 +685,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( - user_id: str, actions: List[Union[dict, str]] + user_id: str, actions: Collection[Union[Mapping, str]] ) -> Tuple[str, str, str, int, int, int]: is_highlight = 1 if _action_has_highlight(actions) else 0 notif = 1 if "notify" in actions else 0 @@ -556,26 +698,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): int(count_as_unread), # unread column ) - def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None: - # We don't use simple_insert_many here to avoid the overhead - # of generating lists of dicts. - - sql = """ - INSERT INTO event_push_actions_staging - (event_id, user_id, actions, notif, highlight, unread) - VALUES (?, ?, ?, ?, ?, ?) - """ - - txn.execute_batch( - sql, - ( - _gen_entry(user_id, actions) - for user_id, actions in user_id_actions.items() - ), - ) - - return await self.db_pool.runInteraction( - "add_push_actions_to_staging", _add_push_actions_to_staging_txn + await self.db_pool.simple_insert_many( + "event_push_actions_staging", + keys=("event_id", "user_id", "actions", "notif", "highlight", "unread"), + values=[ + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.items() + ], + desc="add_push_actions_to_staging", ) async def remove_push_actions_from_staging(self, event_id: str) -> None: @@ -689,12 +819,12 @@ class EventPushActionsWorkerStore(SQLBaseStore): # [10, <none>, 20], we should treat this as being equivalent to # [10, 10, 20]. # - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering <= ?" - " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering <= ? + ORDER BY stream_ordering DESC + LIMIT 1 + """ while range_end - range_start > 0: middle = (range_end + range_start) // 2 @@ -722,14 +852,14 @@ class EventPushActionsWorkerStore(SQLBaseStore): self, stream_ordering: int ) -> Optional[int]: def f(txn: LoggingTransaction) -> Optional[Tuple[int]]: - sql = ( - "SELECT e.received_ts" - " FROM event_push_actions AS ep" - " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" - " WHERE ep.stream_ordering > ? AND notif = 1" - " ORDER BY ep.stream_ordering ASC" - " LIMIT 1" - ) + sql = """ + SELECT e.received_ts + FROM event_push_actions AS ep + JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id + WHERE ep.stream_ordering > ? AND notif = 1 + ORDER BY ep.stream_ordering ASC + LIMIT 1 + """ txn.execute(sql, (stream_ordering,)) return cast(Optional[Tuple[int]], txn.fetchone()) @@ -745,6 +875,19 @@ class EventPushActionsWorkerStore(SQLBaseStore): self._doing_notif_rotation = True try: + # First we recalculate push summaries and delete stale push actions + # for rooms/users with new receipts. + while True: + logger.debug("Handling new receipts") + + caught_up = await self.db_pool.runInteraction( + "_handle_new_receipts_for_notifs_txn", + self._handle_new_receipts_for_notifs_txn, + ) + if caught_up: + break + + # Then we update the event push summaries for any new events while True: logger.info("Rotating notifications") @@ -753,15 +896,129 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) if caught_up: break - await self.hs.get_clock().sleep(self._rotate_delay) + + # Finally we clear out old event push actions. + await self._remove_old_push_actions_that_have_rotated() finally: self._doing_notif_rotation = False + def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool: + """Check for new read receipts and delete from event push actions. + + Any push actions which predate the user's most recent read receipt are + now redundant, so we can remove them from `event_push_actions` and + update `event_push_summary`. + + Returns true if all new receipts have been processed. + """ + + limit = 100 + + # The (inclusive) receipt stream ID that was previously processed.. + min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_last_receipt_stream_id", + keyvalues={}, + retcol="stream_id", + ) + + max_receipts_stream_id = self._receipts_id_gen.get_current_token() + + # The (inclusive) event stream ordering that was previously summarised. + old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + sql = """ + SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering + FROM receipts_linearized AS r + INNER JOIN events AS e USING (event_id) + WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? + ORDER BY r.stream_id ASC + LIMIT ? + """ + + # We only want local users, so we add a dodgy filter to the above query + # and recheck it below. + user_filter = "%:" + self.hs.hostname + + txn.execute( + sql, + ( + min_receipts_stream_id, + max_receipts_stream_id, + user_filter, + limit, + ), + ) + rows = txn.fetchall() + + # For each new read receipt we delete push actions from before it and + # recalculate the summary. + for _, room_id, user_id, stream_ordering in rows: + # Only handle our own read receipts. + if not self.hs.is_mine_id(user_id): + continue + + txn.execute( + """ + DELETE FROM event_push_actions + WHERE room_id = ? + AND user_id = ? + AND stream_ordering <= ? + AND highlight = 0 + """, + (room_id, user_id, stream_ordering), + ) + + # Fetch the notification counts between the stream ordering of the + # latest receipt and what was previously summarised. + notif_count, unread_count = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering + ) + + # Replace the previous summary with the new counts. + self.db_pool.simple_upsert_txn( + txn, + table="event_push_summary", + keyvalues={"room_id": room_id, "user_id": user_id}, + values={ + "notif_count": notif_count, + "unread_count": unread_count, + "stream_ordering": old_rotate_stream_ordering, + "last_receipt_stream_ordering": stream_ordering, + }, + ) + + # We always update `event_push_summary_last_receipt_stream_id` to + # ensure that we don't rescan the same receipts for remote users. + + upper_limit = max_receipts_stream_id + if len(rows) >= limit: + # If we pulled out a limited number of rows we only update the + # position to the last receipt we processed, so we continue + # processing the rest next iteration. + upper_limit = rows[-1][0] + + self.db_pool.simple_update_txn( + txn, + table="event_push_summary_last_receipt_stream_id", + keyvalues={}, + updatevalues={"stream_id": upper_limit}, + ) + + return len(rows) < limit + def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool: - """Archives older notifications into event_push_summary. Returns whether - the archiving process has caught up or not. + """Archives older notifications (from event_push_actions) into event_push_summary. + + Returns whether the archiving process has caught up or not. """ + # The (inclusive) event stream ordering that was previously summarised. old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", @@ -776,50 +1033,64 @@ class EventPushActionsWorkerStore(SQLBaseStore): SELECT stream_ordering FROM event_push_actions WHERE stream_ordering > ? ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? - """, + """, (old_rotate_stream_ordering, self._rotate_count), ) stream_row = txn.fetchone() if stream_row: (offset_stream_ordering,) = stream_row - assert self.stream_ordering_day_ago is not None + + # We need to bound by the current token to ensure that we handle + # out-of-order writes correctly. rotate_to_stream_ordering = min( - self.stream_ordering_day_ago, offset_stream_ordering + offset_stream_ordering, self._stream_id_gen.get_current_token() ) - caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + caught_up = False else: - rotate_to_stream_ordering = self.stream_ordering_day_ago + rotate_to_stream_ordering = self._stream_id_gen.get_current_token() caught_up = True logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) - self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + self._rotate_notifs_before_txn( + txn, old_rotate_stream_ordering, rotate_to_stream_ordering + ) - # We have caught up iff we were limited by `stream_ordering_day_ago` return caught_up def _rotate_notifs_before_txn( - self, txn: LoggingTransaction, rotate_to_stream_ordering: int + self, + txn: LoggingTransaction, + old_rotate_stream_ordering: int, + rotate_to_stream_ordering: int, ) -> None: - old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( - txn, - table="event_push_summary_stream_ordering", - keyvalues={}, - retcol="stream_ordering", - ) + """Archives older notifications (from event_push_actions) into event_push_summary. + + Any event_push_actions between old_rotate_stream_ordering (exclusive) and + rotate_to_stream_ordering (inclusive) will be added to the event_push_summary + table. + + Args: + txn: The database transaction. + old_rotate_stream_ordering: The previous maximum event stream ordering. + rotate_to_stream_ordering: The new maximum event stream ordering to summarise. + """ # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, coalesce(old.%s, 0) + upd.cnt, - upd.stream_ordering, - old.user_id + upd.stream_ordering FROM ( SELECT user_id, room_id, count(*) as cnt, - max(stream_ordering) as stream_ordering - FROM event_push_actions - WHERE ? <= stream_ordering AND stream_ordering < ? - AND highlight = 0 + max(ea.stream_ordering) as stream_ordering + FROM event_push_actions AS ea + LEFT JOIN event_push_summary AS old USING (user_id, room_id) + WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? + AND ( + old.last_receipt_stream_ordering IS NULL + OR old.last_receipt_stream_ordering < ea.stream_ordering + ) AND %s = 1 GROUP BY user_id, room_id ) AS upd @@ -842,7 +1113,6 @@ class EventPushActionsWorkerStore(SQLBaseStore): summaries[(row[0], row[1])] = _EventPushSummary( unread_count=row[2], stream_ordering=row[3], - old_user_id=row[4], notif_count=0, ) @@ -863,115 +1133,93 @@ class EventPushActionsWorkerStore(SQLBaseStore): summaries[(row[0], row[1])] = _EventPushSummary( unread_count=0, stream_ordering=row[3], - old_user_id=row[4], notif_count=row[2], ) logger.info("Rotating notifications, handling %d rows", len(summaries)) - # If the `old.user_id` above is NULL then we know there isn't already an - # entry in the table, so we simply insert it. Otherwise we update the - # existing table. - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", - keys=( - "user_id", - "room_id", - "notif_count", - "unread_count", - "stream_ordering", - ), - values=[ + key_names=("user_id", "room_id"), + key_values=[(user_id, room_id) for user_id, room_id in summaries], + value_names=("notif_count", "unread_count", "stream_ordering"), + value_values=[ ( - user_id, - room_id, summary.notif_count, summary.unread_count, summary.stream_ordering, ) - for ((user_id, room_id), summary) in summaries.items() - if summary.old_user_id is None + for summary in summaries.values() ], ) - txn.execute_batch( - """ - UPDATE event_push_summary - SET notif_count = ?, unread_count = ?, stream_ordering = ? - WHERE user_id = ? AND room_id = ? - """, - ( - ( - summary.notif_count, - summary.unread_count, - summary.stream_ordering, - user_id, - room_id, - ) - for ((user_id, room_id), summary) in summaries.items() - if summary.old_user_id is not None - ), - ) - - txn.execute( - "DELETE FROM event_push_actions" - " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", - (old_rotate_stream_ordering, rotate_to_stream_ordering), - ) - - logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) - txn.execute( "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", (rotate_to_stream_ordering,), ) - def _remove_old_push_actions_before_txn( - self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int - ) -> None: - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. + async def _remove_old_push_actions_that_have_rotated(self) -> None: + """Clear out old push actions that have been summarised.""" - Args: - txn: The transaction - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate, - (room_id, user_id), + # We want to clear out anything that is older than a day that *has* already + # been rotated. + rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol( + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", ) - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + max_stream_ordering_to_delete = min( + rotated_upto_stream_ordering, self.stream_ordering_day_ago ) - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), - ) + def remove_old_push_actions_that_have_rotated_txn( + txn: LoggingTransaction, + ) -> bool: + # We don't want to clear out too much at a time, so we bound our + # deletes. + batch_size = self._rotate_count + + txn.execute( + """ + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering <= ? AND highlight = 0 + ORDER BY stream_ordering ASC LIMIT 1 OFFSET ? + """, + ( + max_stream_ordering_to_delete, + batch_size, + ), + ) + stream_row = txn.fetchone() + + if stream_row: + (stream_ordering,) = stream_row + else: + stream_ordering = max_stream_ordering_to_delete + + # We need to use a inclusive bound here to handle the case where a + # single stream ordering has more than `batch_size` rows. + txn.execute( + """ + DELETE FROM event_push_actions + WHERE stream_ordering <= ? AND highlight = 0 + """, + (stream_ordering,), + ) + + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) + + return txn.rowcount < batch_size + + while True: + done = await self.db_pool.runInteraction( + "_remove_old_push_actions_that_have_rotated", + remove_old_push_actions_that_have_rotated_txn, + ) + if done: + break class EventPushActionsStore(EventPushActionsWorkerStore): @@ -1000,6 +1248,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore): where_clause="highlight=1", ) + # Add index to make deleting old push actions faster. + self.db_pool.updates.register_background_index_update( + "event_push_actions_stream_highlight_index", + index_name="event_push_actions_stream_highlight_index", + table="event_push_actions", + columns=["highlight", "stream_ordering"], + where_clause="highlight=0", + psql_only=True, + ) + async def get_push_actions_for_user( self, user_id: str, @@ -1024,16 +1282,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # NB. This assumes event_ids are globally unique since # it makes the query easier to index - sql = ( - "SELECT epa.event_id, epa.room_id," - " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" - " FROM event_push_actions epa, events e" - " WHERE epa.event_id = e.event_id" - " AND epa.user_id = ? %s" - " AND epa.notif = 1" - " ORDER BY epa.stream_ordering DESC" - " LIMIT ?" % (before_clause,) + sql = """ + SELECT epa.event_id, epa.room_id, + epa.stream_ordering, epa.topological_ordering, + epa.actions, epa.highlight, epa.profile_tag, e.received_ts + FROM event_push_actions epa, events e + WHERE epa.event_id = e.event_id + AND epa.user_id = ? %s + AND epa.notif = 1 + ORDER BY epa.stream_ordering DESC + LIMIT ? + """ % ( + before_clause, ) txn.execute(sql, args) return cast( @@ -1056,7 +1316,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ] -def _action_has_highlight(actions: List[Union[dict, str]]) -> bool: +def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool: for action in actions: if not isinstance(action, dict): continue @@ -1075,5 +1335,4 @@ class _EventPushSummary: unread_count: int stream_ordering: int - old_user_id: str notif_count: int diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 17e35cf63e..a4010ee28d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -16,6 +16,7 @@ import itertools import logging from collections import OrderedDict +from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -35,9 +36,11 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext +from synapse.logging.opentracing import trace from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -46,7 +49,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventCacheEntry from synapse.storage.databases.main.search import SearchEntry -from synapse.storage.engines.postgres import PostgresEngine +from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator from synapse.types import JsonDict, StateMap, get_domain_from_id @@ -69,6 +72,24 @@ event_counter = Counter( ) +class PartialStateConflictError(SynapseError): + """An internal error raised when attempting to persist an event with partial state + after the room containing the event has been un-partial stated. + + This error should be handled by recomputing the event context and trying again. + + This error has an HTTP status code so that it can be transported over replication. + It should not be exposed to clients. + """ + + def __init__(self) -> None: + super().__init__( + HTTPStatus.CONFLICT, + msg="Cannot persist partial state event in un-partial stated room", + errcode=Codes.UNKNOWN, + ) + + @attr.s(slots=True, auto_attribs=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -125,6 +146,7 @@ class PersistEventsStore: self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + @trace async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], @@ -154,6 +176,10 @@ class PersistEventsStore: Returns: Resolves when the events have been persisted + + Raises: + PartialStateConflictError: if attempting to persist a partial state event in + a room that has been un-partial stated. """ # We want to calculate the stream orderings as late as possible, as @@ -354,6 +380,9 @@ class PersistEventsStore: For each room, a list of the event ids which are the forward extremities. + Raises: + PartialStateConflictError: if attempting to persist a partial state event in + a room that has been un-partial stated. """ state_delta_for_room = state_delta_for_room or {} new_forward_extremities = new_forward_extremities or {} @@ -980,16 +1009,16 @@ class PersistEventsStore: self, room_id: str, state_delta: DeltaState, - stream_id: int, ) -> None: """Update the current state stored in the datatabase for the given room""" - await self.db_pool.runInteraction( - "update_current_state", - self._update_current_state_txn, - state_delta_by_room={room_id: state_delta}, - stream_id=stream_id, - ) + async with self._stream_id_gen.get_next() as stream_ordering: + await self.db_pool.runInteraction( + "update_current_state", + self._update_current_state_txn, + state_delta_by_room={room_id: state_delta}, + stream_id=stream_ordering, + ) def _update_current_state_txn( self, @@ -1266,7 +1295,7 @@ class PersistEventsStore: depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids - txn.call_after(self.store._invalidate_get_event_cache, event.event_id) + self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) # Then update the `stream_ordering` position to mark the latest # event as the front of the room. This should not be done for # backfilled events because backfilled events have negative @@ -1304,6 +1333,10 @@ class PersistEventsStore: Returns: new list, without events which are already in the events table. + + Raises: + PartialStateConflictError: if attempting to persist a partial state event in + a room that has been un-partial stated. """ txn.execute( "SELECT event_id, outlier FROM events WHERE event_id in (%s)" @@ -1315,9 +1348,24 @@ class PersistEventsStore: event_id: outlier for event_id, outlier in txn } + logger.debug( + "_update_outliers_txn: events=%s have_persisted=%s", + [ev.event_id for ev, _ in events_and_contexts], + have_persisted, + ) + to_remove = set() for event, context in events_and_contexts: - if event.event_id not in have_persisted: + outlier_persisted = have_persisted.get(event.event_id) + logger.debug( + "_update_outliers_txn: event=%s outlier=%s outlier_persisted=%s", + event.event_id, + event.internal_metadata.is_outlier(), + outlier_persisted, + ) + + # Ignore events which we haven't persisted at all + if outlier_persisted is None: continue to_remove.add(event) @@ -1327,7 +1375,6 @@ class PersistEventsStore: # was an outlier or not - what we have is at least as good. continue - outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: # We received a copy of an event that we had already stored as # an outlier in the database. We now have some state at that event @@ -1338,7 +1385,10 @@ class PersistEventsStore: # events down /sync. In general they will be historical events, so that # doesn't matter too much, but that is not always the case. - logger.info("Updating state for ex-outlier event %s", event.event_id) + logger.info( + "_update_outliers_txn: Updating state for ex-outlier event %s", + event.event_id, + ) # insert into event_to_state_groups. try: @@ -1442,7 +1492,7 @@ class PersistEventsStore: event.sender, "url" in event.content and isinstance(event.content["url"], str), event.get_state_key(), - context.rejected or None, + context.rejected, ) for event, context in events_and_contexts ), @@ -1638,13 +1688,13 @@ class PersistEventsStore: if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill() -> None: + async def prefill() -> None: for cache_entry in to_prefill: - self.store._get_event_cache.set( + await self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry ) - txn.call_after(prefill) + txn.async_call_after(prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: """Invalidate the caches for the redacted event. @@ -1653,7 +1703,7 @@ class PersistEventsStore: _invalidate_caches_for_event. """ assert event.redacts is not None - txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) @@ -1766,6 +1816,18 @@ class PersistEventsStore: self.store.get_invited_rooms_for_local_user.invalidate, (event.state_key,), ) + txn.call_after( + self.store.get_local_users_in_room.invalidate, + (event.room_id,), + ) + txn.call_after( + self.store.get_number_joined_users_in_room.invalidate, + (event.room_id,), + ) + txn.call_after( + self.store.get_user_in_room_with_profile.invalidate, + (event.room_id, event.state_key), + ) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. @@ -2215,6 +2277,11 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: Collection[Tuple[EventBase, EventContext]], ) -> None: + """ + Raises: + PartialStateConflictError: if attempting to persist a partial state event in + a room that has been un-partial stated. + """ state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): @@ -2239,19 +2306,37 @@ class PersistEventsStore: # if we have partial state for these events, record the fact. (This happens # here rather than in _store_event_txn because it also needs to happen when # we de-outlier an event.) - self.db_pool.simple_insert_many_txn( - txn, - table="partial_state_events", - keys=("room_id", "event_id"), - values=[ - ( - event.room_id, - event.event_id, - ) - for event, ctx in events_and_contexts - if ctx.partial_state - ], - ) + try: + self.db_pool.simple_insert_many_txn( + txn, + table="partial_state_events", + keys=("room_id", "event_id"), + values=[ + ( + event.room_id, + event.event_id, + ) + for event, ctx in events_and_contexts + if ctx.partial_state + ], + ) + except self.db_pool.engine.module.IntegrityError: + logger.info( + "Cannot persist events %s in rooms %s: room has been un-partial stated", + [ + event.event_id + for event, ctx in events_and_contexts + if ctx.partial_state + ], + list( + { + event.room_id + for event, ctx in events_and_contexts + if ctx.partial_state + } + ), + ) + raise PartialStateConflictError() self.db_pool.simple_upsert_many_txn( txn, @@ -2296,11 +2381,9 @@ class PersistEventsStore: self.db_pool.simple_insert_many_txn( txn, table="event_edges", - keys=("event_id", "prev_event_id", "room_id", "is_state"), + keys=("event_id", "prev_event_id"), values=[ - (ev.event_id, e_id, ev.room_id, False) - for ev in events - for e_id in ev.prev_event_ids() + (ev.event_id, e_id) for ev in events for e_id in ev.prev_event_ids() ], ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index d5f0059665..6e8aeed7b4 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@ -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# Copyright 2019-2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -64,6 +64,11 @@ class _BackgroundUpdates: INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts" REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" + EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows" + EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index" + + EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections" + @attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: @@ -177,11 +182,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) - # The event_thread_relation background update was replaced with the - # event_arbitrary_relations one, which handles any relation to avoid - # needed to potentially crawl the entire events table in the future. - self.db_pool.updates.register_noop_background_update("event_thread_relation") - self.db_pool.updates.register_background_update_handler( "event_arbitrary_relations", self._event_arbitrary_relations, @@ -240,6 +240,26 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ################################################################################ + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS, + self._background_drop_invalid_event_edges_rows, + ) + + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.EVENT_EDGES_REPLACE_INDEX, + index_name="event_edges_event_id_prev_event_id_idx", + table="event_edges", + columns=["event_id", "prev_event_id"], + unique=True, + # the old index which just covered event_id is now redundant. + replaces_index="ev_edges_id", + ) + + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS, + self._background_events_populate_state_key_rejections, + ) + async def _background_reindex_fields_sender( self, progress: JsonDict, batch_size: int ) -> int: @@ -1290,3 +1310,179 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) return 0 + + async def _background_drop_invalid_event_edges_rows( + self, progress: JsonDict, batch_size: int + ) -> int: + """Drop invalid rows from event_edges + + This only runs for postgres. For SQLite, it all happens synchronously. + + Firstly, drop any rows with is_state=True. These may have been added a long time + ago, but they are no longer used. + + We also drop rows that do not correspond to entries in `events`, and add a + foreign key. + """ + + last_event_id = progress.get("last_event_id", "") + + def drop_invalid_event_edges_txn(txn: LoggingTransaction) -> bool: + """Returns True if we're done.""" + + # first we need to find an endpoint. + txn.execute( + """ + SELECT event_id FROM event_edges + WHERE event_id > ? + ORDER BY event_id + LIMIT 1 OFFSET ? + """, + (last_event_id, batch_size), + ) + + endpoint = None + row = txn.fetchone() + + if row: + endpoint = row[0] + + where_clause = "ee.event_id > ?" + args = [last_event_id] + if endpoint: + where_clause += " AND ee.event_id <= ?" + args.append(endpoint) + + # now delete any that: + # - have is_state=TRUE, or + # - do not correspond to a row in `events` + txn.execute( + f""" + DELETE FROM event_edges + WHERE event_id IN ( + SELECT ee.event_id + FROM event_edges ee + LEFT JOIN events ev USING (event_id) + WHERE ({where_clause}) AND + (is_state OR ev.event_id IS NULL) + )""", + args, + ) + + logger.info( + "cleaned up event_edges up to %s: removed %i/%i rows", + endpoint, + txn.rowcount, + batch_size, + ) + + if endpoint is not None: + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS, + {"last_event_id": endpoint}, + ) + return False + + # if that was the final batch, we validate the foreign key. + # + # The constraint should have been in place and enforced for new rows since + # before we started deleting invalid rows, so there's no chance for any + # invalid rows to have snuck in the meantime. In other words, this really + # ought to succeed. + logger.info("cleaned up event_edges; enabling foreign key") + txn.execute( + "ALTER TABLE event_edges VALIDATE CONSTRAINT event_edges_event_id_fkey" + ) + return True + + done = await self.db_pool.runInteraction( + desc="drop_invalid_event_edges", func=drop_invalid_event_edges_txn + ) + + if done: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.EVENT_EDGES_DROP_INVALID_ROWS + ) + + return batch_size + + async def _background_events_populate_state_key_rejections( + self, progress: JsonDict, batch_size: int + ) -> int: + """Back-populate `events.state_key` and `events.rejection_reason""" + + min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"] + max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"] + + def _populate_txn(txn: LoggingTransaction) -> bool: + """Returns True if we're done.""" + + # first we need to find an endpoint. + # we need to find the final row in the batch of batch_size, which means + # we need to skip over (batch_size-1) rows and get the next row. + txn.execute( + """ + SELECT stream_ordering FROM events + WHERE stream_ordering > ? AND stream_ordering <= ? + ORDER BY stream_ordering + LIMIT 1 OFFSET ? + """, + ( + min_stream_ordering_exclusive, + max_stream_ordering_inclusive, + batch_size - 1, + ), + ) + + endpoint = None + row = txn.fetchone() + if row: + endpoint = row[0] + + where_clause = "stream_ordering > ?" + args = [min_stream_ordering_exclusive] + if endpoint: + where_clause += " AND stream_ordering <= ?" + args.append(endpoint) + + # now do the updates. + txn.execute( + f""" + UPDATE events + SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id), + rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id) + WHERE ({where_clause}) + """, + args, + ) + + logger.info( + "populated new `events` columns up to %s/%i: updated %i rows", + endpoint, + max_stream_ordering_inclusive, + txn.rowcount, + ) + + if endpoint is None: + # we're done + return True + + progress["min_stream_ordering_exclusive"] = endpoint + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS, + progress, + ) + return False + + done = await self.db_pool.runInteraction( + desc="events_populate_state_key_rejections", func=_populate_txn + ) + + if done: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS + ) + + return batch_size diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..9b997c304d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) +from synapse.logging.opentracing import start_active_span, tag_args, trace from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -79,7 +80,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import AsyncLruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -238,7 +239,9 @@ class EventsWorkerStore(SQLBaseStore): 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache: AsyncLruCache[ + Tuple[str], EventCacheEntry + ] = AsyncLruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -292,25 +295,6 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - async def get_received_ts(self, event_id: str) -> Optional[int]: - """Get received_ts (when it was persisted) for the event. - - Raises an exception for unknown events. - - Args: - event_id: The event ID to query. - - Returns: - Timestamp in milliseconds, or None for events that were persisted - before received_ts was implemented. - """ - return await self.db_pool.simple_select_one_onecol( - table="events", - keyvalues={"event_id": event_id}, - retcol="received_ts", - desc="get_received_ts", - ) - async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. @@ -447,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} + @trace + @tag_args async def get_events_as_list( self, event_ids: Collection[str], @@ -617,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore): Returns: map from event id to result """ - event_entry_map = self._get_events_from_cache( + # Shortcut: check if we have any events in the *in memory* cache - this function + # may be called repeatedly for the same event so at this point we cannot reach + # out to any external cache for performance reasons. The external cache is + # checked later on in the `get_missing_events_from_cache_or_db` function below. + event_entry_map = self._get_events_from_local_cache( event_ids, ) @@ -649,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore): if missing_events_ids: - async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]: + async def get_missing_events_from_cache_or_db() -> Dict[ + str, EventCacheEntry + ]: """Fetches the events in `missing_event_ids` from the database. Also creates entries in `self._current_event_fetches` to allow @@ -674,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore): # the events have been redacted, and if so pulling the redaction event # out of the database to check it. # + missing_events = {} try: - missing_events = await self._get_events_from_db( + # Try to fetch from any external cache. We already checked the + # in-memory cache above. + missing_events = await self._get_events_from_external_cache( missing_events_ids, ) + # Now actually fetch any remaining events from the DB + db_missing_events = await self._get_events_from_db( + missing_events_ids - missing_events.keys(), + ) + missing_events.update(db_missing_events) except Exception as e: with PreserveLoggingContext(): fetching_deferred.errback(e) @@ -696,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore): # cancellations, since multiple `_get_events_from_cache_or_db` calls can # reuse the same fetch. missing_events: Dict[str, EventCacheEntry] = await delay_cancellation( - get_missing_events_from_db() + get_missing_events_from_cache_or_db() ) event_entry_map.update(missing_events) @@ -729,15 +729,96 @@ class EventsWorkerStore(SQLBaseStore): return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: - self._get_event_cache.invalidate((event_id,)) + def invalidate_get_event_cache_after_txn( + self, txn: LoggingTransaction, event_id: str + ) -> None: + """ + Prepares a database transaction to invalidate the get event cache for a given + event ID when executed successfully. This is achieved by attaching two callbacks + to the transaction, one to invalidate the async cache and one for the in memory + sync cache (importantly called in that order). + + Arguments: + txn: the database transaction to attach the callbacks to + event_id: the event ID to be invalidated from caches + """ + + txn.async_call_after(self._invalidate_async_get_event_cache, event_id) + txn.call_after(self._invalidate_local_get_event_cache, event_id) + + async def _invalidate_async_get_event_cache(self, event_id: str) -> None: + """ + Invalidates an event in the asyncronous get event cache, which may be remote. + + Arguments: + event_id: the event ID to invalidate + """ + + await self._get_event_cache.invalidate((event_id,)) + + def _invalidate_local_get_event_cache(self, event_id: str) -> None: + """ + Invalidates an event in local in-memory get event caches. + + Arguments: + event_id: the event ID to invalidate + """ + + self._get_event_cache.invalidate_local((event_id,)) self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) - def _get_events_from_cache( + async def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, EventCacheEntry]: + """Fetch events from the caches, both in memory and any external. + + May return rejected events. + + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics + """ + event_map = self._get_events_from_local_cache( + events, update_metrics=update_metrics + ) + + missing_event_ids = (e for e in events if e not in event_map) + event_map.update( + await self._get_events_from_external_cache( + events=missing_event_ids, + update_metrics=update_metrics, + ) + ) + + return event_map + + async def _get_events_from_external_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, EventCacheEntry]: + """Fetch events from any configured external cache. + + May return rejected events. + + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics + """ + event_map = {} + + for event_id in events: + ret = await self._get_event_cache.get_external( + (event_id,), None, update_metrics=update_metrics + ) + if ret: + event_map[event_id] = ret + + return event_map + + def _get_events_from_local_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: - """Fetch events from the caches. + """Fetch events from the local, in memory, caches. May return rejected events. @@ -749,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore): for event_id in events: # First check if it's in the event cache - ret = self._get_event_cache.get( + ret = self._get_event_cache.get_local( (event_id,), None, update_metrics=update_metrics ) if ret: @@ -771,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore): # We add the entry back into the cache as we want to keep # recently queried events in the cache. - self._get_event_cache.set((event_id,), cache_entry) + self._get_event_cache.set_local((event_id,), cache_entry) return event_map @@ -965,7 +1046,13 @@ class EventsWorkerStore(SQLBaseStore): } row_dict = self.db_pool.new_transaction( - conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch + conn, + "do_fetch", + [], + [], + [], + self._fetch_event_rows, + events_to_fetch, ) # We only want to resolve deferreds from the main thread @@ -1006,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore): """ fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} - events_to_fetch = event_ids - while events_to_fetch: - row_map = await self._enqueue_events(events_to_fetch) + async def _fetch_event_ids_and_get_outstanding_redactions( + event_ids_to_fetch: Collection[str], + ) -> Collection[str]: + """ + Fetch all of the given event_ids and return any associated redaction event_ids + that we still need to fetch in the next iteration. + """ + row_map = await self._enqueue_events(event_ids_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids: Set[str] = set() - for event_id in events_to_fetch: + for event_id in event_ids_to_fetch: row = row_map.get(event_id) fetched_event_ids.add(event_id) if row: fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) - if events_to_fetch: - logger.debug("Also fetching redaction events %s", events_to_fetch) + event_ids_to_fetch = redaction_ids.difference(fetched_event_ids) + return event_ids_to_fetch + + # Grab the initial list of events requested + event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions( + event_ids + ) + # Then go and recursively find all of the associated redactions + with start_active_span("recursively fetching redactions"): + while event_ids_to_fetch: + logger.debug("Also fetching redaction events %s", event_ids_to_fetch) + + event_ids_to_fetch = ( + await _fetch_event_ids_and_get_outstanding_redactions( + event_ids_to_fetch + ) + ) # build a map from event_id to EventBase event_map: Dict[str, EventBase] = {} @@ -1148,7 +1254,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry if not redacted_event: @@ -1340,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} + @trace + @tag_args async def have_seen_events( self, room_id: str, event_ids: Iterable[str] ) -> Set[str]: @@ -1382,7 +1490,9 @@ class EventsWorkerStore(SQLBaseStore): # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) + (rid, eid) + for (rid, eid) in keys + if await self._get_event_cache.contains((eid,)) } results = dict.fromkeys(cache_results, True) remaining = [k for k in keys if k not in cache_results] @@ -1465,7 +1575,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_all_new_forward_event_rows( self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: """Returns new events, for the Events replication stream Args: @@ -1481,10 +1591,11 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," + " e.outlier" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events AS se USING (event_id)" @@ -1498,7 +1609,8 @@ class EventsWorkerStore(SQLBaseStore): ) txn.execute(sql, (last_id, current_id, instance_name, limit)) return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + List[Tuple[int, str, str, str, str, str, str, str, bool, bool]], + txn.fetchall(), ) return await self.db_pool.runInteraction( @@ -1507,7 +1619,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_ex_outlier_stream_rows( self, instance_name: str, last_id: int, current_id: int - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: """Returns de-outliered events, for the Events replication stream Args: @@ -1522,11 +1634,14 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: + ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]: sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" + " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL," + " e.outlier" " FROM events AS e" + # NB: the next line (inner join) is what makes this query different from + # get_all_new_forward_event_rows. " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events AS se USING (event_id)" @@ -1541,7 +1656,8 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, instance_name)) return cast( - List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() + List[Tuple[int, str, str, str, str, str, str, str, bool, bool]], + txn.fetchall(), ) return await self.db_pool.runInteraction( @@ -1995,7 +2111,14 @@ class EventsWorkerStore(SQLBaseStore): AND room_id = ? /* Make sure event is not rejected */ AND rejections.event_id IS NULL - ORDER BY origin_server_ts %s + /** + * First sort by the message timestamp. If the message timestamps are the + * same, we want the message that logically comes "next" (before/after + * the given timestamp) based on the DAG and its topological order (`depth`). + * Finally, we can tie-break based on when it was received on the server + * (`stream_ordering`). + */ + ORDER BY origin_server_ts %s, depth %s, stream_ordering %s LIMIT 1; """ @@ -2014,7 +2137,8 @@ class EventsWorkerStore(SQLBaseStore): order = "ASC" txn.execute( - sql_template % (comparison_operator, order), (timestamp, room_id) + sql_template % (comparison_operator, order, order, order), + (timestamp, room_id), ) row = txn.fetchone() if row: @@ -2079,14 +2203,92 @@ class EventsWorkerStore(SQLBaseStore): def _get_partial_state_events_batch_txn( txn: LoggingTransaction, room_id: str ) -> List[str]: + # we want to work through the events from oldest to newest, so + # we only want events whose prev_events do *not* have partial state - hence + # the 'NOT EXISTS' clause in the below. + # + # This is necessary because ordering by stream ordering isn't quite enough + # to ensure that we work from oldest to newest event (in particular, + # if an event is initially persisted as an outlier and later de-outliered, + # it can end up with a lower stream_ordering than its prev_events). + # + # Typically this means we'll only return one event per batch, but that's + # hard to do much about. + # + # See also: https://github.com/matrix-org/synapse/issues/13001 txn.execute( """ SELECT event_id FROM partial_state_events AS pse JOIN events USING (event_id) - WHERE pse.room_id = ? + WHERE pse.room_id = ? AND + NOT EXISTS( + SELECT 1 FROM event_edges AS ee + JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id) + WHERE ee.event_id=pse.event_id + ) ORDER BY events.stream_ordering LIMIT 100 """, (room_id,), ) return [row[0] for row in txn] + + def mark_event_rejected_txn( + self, + txn: LoggingTransaction, + event_id: str, + rejection_reason: Optional[str], + ) -> None: + """Mark an event that was previously accepted as rejected, or vice versa + + This can happen, for example, when resyncing state during a faster join. + + Args: + txn: + event_id: ID of event to update + rejection_reason: reason it has been rejected, or None if it is now accepted + """ + if rejection_reason is None: + logger.info( + "Marking previously-processed event %s as accepted", + event_id, + ) + self.db_pool.simple_delete_txn( + txn, + "rejections", + keyvalues={"event_id": event_id}, + ) + else: + logger.info( + "Marking previously-processed event %s as rejected(%s)", + event_id, + rejection_reason, + ) + self.db_pool.simple_upsert_txn( + txn, + table="rejections", + keyvalues={"event_id": event_id}, + values={ + "reason": rejection_reason, + "last_check": self._clock.time_msec(), + }, + ) + self.db_pool.simple_update_txn( + txn, + table="events", + keyvalues={"event_id": event_id}, + updatevalues={"rejection_reason": rejection_reason}, + ) + + self.invalidate_get_event_cache_after_txn(txn, event_id) + + # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just + # call '_send_invalidation_to_replication', but we actually need the other + # end to call _invalidate_local_get_event_cache() rather than (just) + # _get_event_cache.invalidate(). + # + # One solution might be to (somehow) get the workers to call + # _invalidate_caches_for_event() (though that will invalidate more than + # strictly necessary). + # + # https://github.com/matrix-org/synapse/issues/12994 diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py deleted file mode 100644
index c15a7136b6..0000000000 --- a/synapse/storage/databases/main/group_server.py +++ /dev/null
@@ -1,34 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class GroupServerStore(SQLBaseStore): - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - # Register a legacy groups background update as a no-op. - database.updates.register_noop_background_update("local_group_updates_index") - super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index d028be16de..9b172a64d8 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -37,9 +37,6 @@ from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer -BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( - "media_repository_drop_index_wo_method" -) BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( "media_repository_drop_index_wo_method_2" ) @@ -111,13 +108,6 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): unique=True, ) - # the original impl of _drop_media_index_without_method was broken (see - # https://github.com/matrix-org/synapse/issues/8649), so we replace the original - # impl with a no-op and run the fixed migration as - # media_repository_drop_index_wo_method_2. - self.db_pool.updates.register_noop_background_update( - BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD - ) self.db_pool.updates.register_background_update_handler( BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2, self._drop_media_index_without_method, diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): "initialise_mau_threepids", [], [], + [], self._initialise_reserved_users, hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ba385f9fc4..f6822707e4 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.engines._base import IsolationLevel from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) @@ -214,10 +216,10 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # Delete all remote non-state events for table in ( + "event_edges", "events", "event_json", "event_auth", - "event_edges", "event_forward_extremities", "event_relations", "event_search", @@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) - self._invalidate_get_event_cache(event_id) + self.invalidate_get_event_cache_after_txn(txn, event_id) logger.info("[purge] done") @@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): Returns: The list of state groups to delete. """ - return await self.db_pool.runInteraction( - "purge_room", self._purge_room_txn, room_id + + # This first runs the purge transaction with READ_COMMITTED isolation level, + # meaning any new rows in the tables will not trigger a serialization error. + # We then run the same purge a second time without this isolation level to + # purge any of those rows which were added during the first. + + state_groups_to_delete = await self.db_pool.runInteraction( + "purge_room", + self._purge_room_txn, + room_id=room_id, + isolation_level=IsolationLevel.READ_COMMITTED, + ) + + state_groups_to_delete.extend( + await self.db_pool.runInteraction( + "purge_room", + self._purge_room_txn, + room_id=room_id, + ), ) + return state_groups_to_delete + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: + # This collides with event persistence so we cannot write new events and metadata into + # a room while deleting it or this transaction will fail. + if isinstance(self.database_engine, PostgresEngine): + txn.execute( + "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE", + (room_id,), + ) + # First, fetch all the state groups that should be deleted, before # we delete that information. txn.execute( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d5aefe02b6..5079edd1e0 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, +) from synapse.api.errors import StoreError from synapse.config.homeserver import ExperimentalConfig -from synapse.push.baserules import list_with_base_rules +from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( @@ -50,69 +62,39 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _is_experimental_rule_enabled( - rule_id: str, experimental_config: ExperimentalConfig -) -> bool: - """Used by `_load_rules` to filter out experimental rules when they - have not been enabled. - """ - if ( - rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl" - and not experimental_config.msc3786_enabled - ): - return False - if ( - rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - and not experimental_config.msc3772_enabled - ): - return False - return True - - def _load_rules( rawrules: List[JsonDict], enabled_map: Dict[str, bool], experimental_config: ExperimentalConfig, -) -> List[JsonDict]: - ruleslist = [] - for rawrule in rawrules: - rule = dict(rawrule) - rule["conditions"] = db_to_json(rawrule["conditions"]) - rule["actions"] = db_to_json(rawrule["actions"]) - rule["default"] = False - ruleslist.append(rule) - - # We're going to be mutating this a lot, so copy it. We also filter out - # any experimental default push rules that aren't enabled. - rules = [ - rule - for rule in list_with_base_rules(ruleslist) - if _is_experimental_rule_enabled(rule["rule_id"], experimental_config) - ] +) -> FilteredPushRules: + """Take the DB rows returned from the DB and convert them into a full + `FilteredPushRules` object. + """ - for i, rule in enumerate(rules): - rule_id = rule["rule_id"] + ruleslist = [ + PushRule( + rule_id=rawrule["rule_id"], + priority_class=rawrule["priority_class"], + conditions=db_to_json(rawrule["conditions"]), + actions=db_to_json(rawrule["actions"]), + ) + for rawrule in rawrules + ] - if rule_id not in enabled_map: - continue - if rule.get("enabled", True) == bool(enabled_map[rule_id]): - continue + push_rules = compile_push_rules(ruleslist) - # Rules are cached across users. - rule = dict(rule) - rule["enabled"] = bool(enabled_map[rule_id]) - rules[i] = rule + filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config) - return rules + return filtered_rules # The ABCMeta metaclass ensures that it cannot be instantiated without # the abstract methods being implemented. class PushRulesWorkerStore( ApplicationServiceWorkerStore, - ReceiptsWorkerStore, PusherWorkerStore, RoomMemberWorkerStore, + ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta, @@ -162,7 +144,7 @@ class PushRulesWorkerStore( raise NotImplementedError() @cached(max_entries=5000) - async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]: + async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, @@ -183,7 +165,6 @@ class PushRulesWorkerStore( return _load_rules(rows, enabled_map, self.hs.config.experimental) - @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( table="push_rules_enable", @@ -216,11 +197,11 @@ class PushRulesWorkerStore( @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] - ) -> Dict[str, List[JsonDict]]: + ) -> Dict[str, FilteredPushRules]: if not user_ids: return {} - results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} + raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids} rows = await self.db_pool.simple_select_many_batch( table="push_rules", @@ -228,25 +209,25 @@ class PushRulesWorkerStore( iterable=user_ids, retcols=("*",), desc="bulk_get_push_rules", + batch_size=1000, ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) for row in rows: - results.setdefault(row["user_name"], []).append(row) + raw_rules.setdefault(row["user_name"], []).append(row) enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) - for user_id, rules in results.items(): + results: Dict[str, FilteredPushRules] = {} + + for user_id, rules in raw_rules.items(): results[user_id] = _load_rules( rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental ) return results - @cachedList( - cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids" - ) async def bulk_get_push_rules_enabled( self, user_ids: Collection[str] ) -> Dict[str, Dict[str, bool]]: @@ -261,6 +242,7 @@ class PushRulesWorkerStore( iterable=user_ids, retcols=("user_name", "rule_id", "enabled"), desc="bulk_get_push_rules_enabled", + batch_size=1000, ) for row in rows: enabled = bool(row["enabled"]) @@ -344,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore): user_id: str, rule_id: str, priority_class: int, - conditions: List[Dict[str, str]], - actions: List[Union[JsonDict, str]], + conditions: Sequence[Mapping[str, str]], + actions: Sequence[Union[Mapping[str, Any], str]], before: Optional[str] = None, after: Optional[str] = None, ) -> None: @@ -807,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore): self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) - txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) txn.call_after( self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) @@ -816,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore): return self._push_rules_stream_id_gen.get_current_token() async def copy_push_rule_from_room_to_room( - self, new_room_id: str, user_id: str, rule: dict + self, new_room_id: str, user_id: str, rule: PushRule ) -> None: """Copy a single push rule from one room to another for a specific user. @@ -826,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore): rule: A push rule. """ # Create new rule id - rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1]) + rule_id_scope = "/".join(rule.rule_id.split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id + new_conditions = [] + # Change room id in each condition - for condition in rule.get("conditions", []): + for condition in rule.conditions: + new_condition = condition if condition.get("key") == "room_id": - condition["pattern"] = new_room_id + new_condition = dict(condition) + new_condition["pattern"] = new_room_id + + new_conditions.append(new_condition) # Add the rule for the new room await self.add_push_rule( user_id=user_id, rule_id=new_rule_id, - priority_class=rule["priority_class"], - conditions=rule["conditions"], - actions=rule["actions"], + priority_class=rule.priority_class, + conditions=new_conditions, + actions=rule.actions, ) async def copy_push_rules_from_room_to_room_for_user( @@ -858,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore): user_push_rules = await self.get_push_rules_for_user(user_id) # Get rules relating to the old room and copy them to the new room - for rule in user_push_rules: - conditions = rule.get("conditions", []) + for rule, enabled in user_push_rules: + if not enabled: + continue + + conditions = rule.conditions if any( (c.get("key") == "room_id" and c.get("pattern") == old_room_id) for c in conditions diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 21e954ccc1..124c70ad37 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -26,7 +26,7 @@ from typing import ( cast, ) -from synapse.api.constants import EduTypes, ReceiptTypes +from synapse.api.constants import EduTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -36,6 +36,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.storage.engines._base import IsolationLevel from synapse.storage.util.id_generators import ( AbstractStreamIdTracker, MultiWriterIdGenerator, @@ -117,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return self._receipts_id_gen.get_current_token() async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_types: Iterable[str] + self, user_id: str, room_id: str, receipt_types: Collection[str] ) -> Optional[str]: """ Fetch the event ID for the latest receipt in a room with one of the given receipt types. @@ -125,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore): Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt types to fetch. Earlier receipt types - are given priority if multiple receipts point to the same event. + receipt_type: The receipt types to fetch. Returns: The latest receipt, if one exists. """ - latest_event_id: Optional[str] = None - latest_stream_ordering = 0 - for receipt_type in receipt_types: - result = await self._get_last_receipt_event_id_for_user( - user_id, room_id, receipt_type - ) - if result is None: - continue - event_id, stream_ordering = result - - if latest_event_id is None or latest_stream_ordering < stream_ordering: - latest_event_id = event_id - latest_stream_ordering = stream_ordering + result = await self.db_pool.runInteraction( + "get_last_receipt_event_id_for_user", + self.get_last_receipt_for_user_txn, + user_id, + room_id, + receipt_types, + ) + if not result: + return None - return latest_event_id + event_id, _ = result + return event_id - @cached() - async def _get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_type: str + def get_last_receipt_for_user_txn( + self, + txn: LoggingTransaction, + user_id: str, + room_id: str, + receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ - Fetch the event ID and stream ordering for the latest receipt. + Fetch the event ID and stream_ordering for the latest receipt in a room + with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt type to fetch. + receipt_type: The receipt types to fetch. Returns: - The event ID and stream ordering of the latest receipt, if one exists; - otherwise `None`. + The event ID and stream ordering of the latest receipt, if one exists. """ - sql = """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "receipt_type", receipt_types + ) + + sql = f""" SELECT event_id, stream_ordering FROM receipts_linearized INNER JOIN events USING (room_id, event_id) - WHERE user_id = ? + WHERE {clause} + AND user_id = ? AND room_id = ? - AND receipt_type = ? + ORDER BY stream_ordering DESC + LIMIT 1 """ - def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]: - txn.execute(sql, (user_id, room_id, receipt_type)) - return cast(Optional[Tuple[str, int]], txn.fetchone()) + args.extend((user_id, room_id)) + txn.execute(sql, args) - return await self.db_pool.runInteraction("get_own_receipt_for_user", f) + return cast(Optional[Tuple[str, int]], txn.fetchone()) async def get_receipts_for_user( self, user_id: str, receipt_types: Iterable[str] @@ -576,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore): ) -> None: self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) - self._get_last_receipt_event_id_for_user.invalidate( - (user_id, room_id, receipt_type) + + # We use this method to invalidate so that we don't end up with circular + # dependencies between the receipts and push action stores. + self._attempt_to_invalidate_cache( + "get_unread_event_push_actions_by_room_for_user", (room_id,) ) def process_replication_rows( @@ -673,17 +682,6 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - # When updating a local users read receipt, remove any push actions - # which resulted from the receipt's event and all earlier events. - if ( - self.hs.is_mine_id(user_id) - and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) - and stream_ordering is not None - ): - self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] - txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering - ) - return rx_ts def _graph_to_linear( @@ -764,6 +762,10 @@ class ReceiptsWorkerStore(SQLBaseStore): linearized_event_id, data, stream_id=stream_id, + # Read committed is actually beneficial here because we check for a receipt with + # greater stream order, and checking the very latest data at select time is better + # than the data at transaction start time. + isolation_level=IsolationLevel.READ_COMMITTED, ) # If the receipt was older than the currently persisted one, nothing to do. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4991360b70..7fb9c801da 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult: """ user_id: str + token_id: int is_guest: bool = False shadow_banned: bool = False - token_id: Optional[int] = None device_id: Optional[str] = None valid_until_ms: Optional[int] = None token_owner: str = attr.ib() @@ -1805,21 +1805,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): columns=["creation_ts"], ) - # we no longer use refresh tokens, but it's possible that some people - # might have a background update queued to build this index. Just - # clear the background update. - self.db_pool.updates.register_noop_background_update( - "refresh_tokens_device_index" - ) - self.db_pool.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) - self.db_pool.updates.register_noop_background_update( - "user_threepids_grandfather" - ) - self.db_pool.updates.register_background_index_update( "user_external_ids_user_id_idx", index_name="user_external_ids_user_id_idx", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore): room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, - aggregation_key: Optional[str] = None, limit: int = 5, direction: str = "b", from_token: Optional[StreamToken] = None, @@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore): room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. - aggregation_key: Only fetch events with this aggregation key, if given. limit: Only fetch the most recent `limit` events. direction: Whether to fetch the most recent first (`"b"`) or the oldest first (`"f"`). @@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore): where_clause.append("type = ?") where_args.append(event_type) - if aggregation_key: - where_clause.append("aggregation_key = ?") - where_args.append(aggregation_key) - pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 68d4fc2e64..bef66f1992 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import ( import attr -from synapse.api.constants import EventContentFields, EventTypes, JoinRules +from synapse.api.constants import ( + EventContentFields, + EventTypes, + JoinRules, + PublicRoomsFilterFields, +) from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase -from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, @@ -170,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): rooms.creator, state.encryption, state.is_federatable AS federatable, rooms.is_public AS public, state.join_rules, state.guest_access, state.history_visibility, curr.current_state_events AS state_events, - state.avatar, state.topic + state.avatar, state.topic, state.room_type FROM rooms LEFT JOIN room_stats_state state USING (room_id) LEFT JOIN room_stats_current curr USING (room_id) @@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_public_room_ids", ) + def _construct_room_type_where_clause( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[Union[str, None], List[str]]: + if not room_types: + return None, [] + else: + # We use None when we want get rooms without a type + is_null_clause = "" + if None in room_types: + is_null_clause = "OR room_type IS NULL" + room_types = [value for value in room_types if value is not None] + + list_clause, args = make_in_list_sql_clause( + self.database_engine, "room_type", room_types + ) + + return f"({list_clause} {is_null_clause})", args + async def count_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], ignore_non_federatable: bool, + search_filter: Optional[dict], ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. @@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): Args: network_tuple ignore_non_federatable: If true filters out non-federatable rooms + search_filter """ def _count_public_rooms_txn(txn: LoggingTransaction) -> int: query_args = [] + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + room_type_clause = f" AND {room_type_clause}" if room_type_clause else "" + query_args += args + if network_tuple: if network_tuple.appservice_id: published_sql = """ @@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) + {room_type_clause} AND joined_members > 0 """ @@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): if ignore_non_federatable: where_clauses.append("is_federatable") - if search_filter and search_filter.get("generic_search_term", None): - search_term = "%" + search_filter["generic_search_term"] + "%" + if search_filter and search_filter.get( + PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None + ): + search_term = ( + "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%" + ) where_clauses.append( """ @@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): search_term.lower(), ] + room_type_clause, args = self._construct_room_type_where_clause( + search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None) + if search_filter + else None + ) + if room_type_clause: + where_clauses.append(room_type_clause) + query_args += args + where_clause = "" if where_clauses: where_clause = " AND " + " AND ".join(where_clauses) @@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): sql = f""" SELECT room_id, name, topic, canonical_alias, joined_members, - avatar, history_visibility, guest_access, join_rules + avatar, history_visibility, guest_access, join_rules, room_type FROM ( {published_sql} ) published @@ -549,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, curr.local_users_in_room, rooms.room_version, rooms.creator, state.encryption, state.is_federatable, rooms.is_public, state.join_rules, - state.guest_access, state.history_visibility, curr.current_state_events + state.guest_access, state.history_visibility, curr.current_state_events, + state.room_type FROM room_stats_state state INNER JOIN room_stats_current curr USING (room_id) INNER JOIN rooms USING (room_id) @@ -593,12 +641,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): "version": room[5], "creator": room[6], "encryption": room[7], - "federatable": room[8], - "public": room[9], + # room_stats_state.federatable is an integer on sqlite. + "federatable": bool(room[8]), + # rooms.is_public is an integer on sqlite. + "public": bool(room[9]), "join_rules": room[10], "guest_access": room[11], "history_visibility": room[12], "state_events": room[13], + "room_type": room[14], } ) @@ -1109,25 +1160,34 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return room_servers async def clear_partial_state_room(self, room_id: str) -> bool: - # this can race with incoming events, so we watch out for FK errors. - # TODO(faster_joins): this still doesn't completely fix the race, since the persist process - # is not atomic. I fear we need an application-level lock. + """Clears the partial state flag for a room. + + Args: + room_id: The room whose partial state flag is to be cleared. + + Returns: + `True` if the partial state flag has been cleared successfully. + + `False` if the partial state flag could not be cleared because the room + still contains events with partial state. + """ try: await self.db_pool.runInteraction( "clear_partial_state_room", self._clear_partial_state_room_txn, room_id ) return True - except self.db_pool.engine.module.DatabaseError as e: - # TODO(faster_joins): how do we distinguish between FK errors and other errors? - logger.warning( + except self.db_pool.engine.module.IntegrityError as e: + # Assume that any `IntegrityError`s are due to partial state events. + logger.info( "Exception while clearing lazy partial-state-room %s, retrying: %s", room_id, e, ) return False - @staticmethod - def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None: + def _clear_partial_state_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> None: DatabasePool.simple_delete_txn( txn, table="partial_state_rooms_servers", @@ -1138,7 +1198,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): table="partial_state_rooms", keyvalues={"room_id": room_id}, ) + self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) + @cached() async def is_partial_state_room(self, room_id: str) -> bool: """Checks if this room has partial state. @@ -1164,6 +1226,7 @@ class _BackgroundUpdates: POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column" + ADD_ROOM_TYPE_COLUMN = "add_room_type_column" _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( @@ -1198,6 +1261,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_add_rooms_room_version_column, ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + self._background_add_room_type_column, + ) + # BG updates to change the type of room_depth.min_depth self.db_pool.updates.register_background_update_handler( _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2, @@ -1567,6 +1635,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return batch_size + async def _background_add_room_type_column( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to go and add room_type information to `room_stats_state` + table from `event_json` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_add_room_type_column_txn( + txn: LoggingTransaction, + ) -> bool: + sql = """ + SELECT state.room_id, json FROM event_json + INNER JOIN current_state_events AS state USING (event_id) + WHERE state.room_id > ? AND type = 'm.room.create' + ORDER BY state.room_id + LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + room_id_to_create_event_results = txn.fetchall() + + new_last_room_id = None + for room_id, event_json in room_id_to_create_event_results: + event_dict = db_to_json(event_json) + + room_type = event_dict.get("content", {}).get( + EventContentFields.ROOM_TYPE, None + ) + if isinstance(room_type, str): + self.db_pool.simple_update_txn( + txn, + table="room_stats_state", + keyvalues={"room_id": room_id}, + updatevalues={"room_type": room_type}, + ) + + new_last_room_id = room_id + + if new_last_room_id is None: + return True + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + {"room_id": new_last_room_id}, + ) + + return False + + end = await self.db_pool.runInteraction( + "_background_add_room_type_column", + _background_add_room_type_column_txn, + ) + + if end: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN + ) + + return batch_size + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): def __init__( @@ -1643,9 +1774,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): servers, ) - @staticmethod def _store_partial_state_room_txn( - txn: LoggingTransaction, room_id: str, servers: Collection[str] + self, txn: LoggingTransaction, room_id: str, servers: Collection[str] ) -> None: DatabasePool.simple_insert_txn( txn, @@ -1660,6 +1790,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): keys=("room_id", "server_name"), values=((room_id, s) for s in servers), ) + self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,)) async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion @@ -1875,9 +2006,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + # We join on room_stats_state despite not using any columns from it + # because the join can influence the number of rows returned; + # e.g. a room that doesn't have state, maybe because it was deleted. + # The query returning the total count should be consistent with + # the query returning the results. sql = """ SELECT COUNT(*) as total_event_reports FROM event_reports AS er + JOIN room_stats_state ON room_stats_state.room_id = er.room_id {} """.format( where_clause diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 31bc8c5601..4f0adb136a 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -30,8 +31,6 @@ from typing import ( import attr from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase -from synapse.events.snapshot import EventContext from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import ( run_as_background_process, @@ -56,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList +from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -184,34 +184,109 @@ class RoomMemberWorkerStore(EventsWorkerStore): self._check_safe_current_state_events_membership_updated_txn, ) - @cached(max_entries=100000, iterable=True, prune_unread_entries=False) + @cached(max_entries=100000, iterable=True) async def get_users_in_room(self, room_id: str) -> List[str]: + """ + Returns a list of users in the room sorted by longest in the room first + (aka. with the lowest depth). This is done to match the sort in + `get_current_hosts_in_room()` and so we can re-use the cache but it's + not horrible to have here either. + """ + return await self.db_pool.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]: + """ + Returns a list of users in the room sorted by longest in the room first + (aka. with the lowest depth). This is done to match the sort in + `get_current_hosts_in_room()` and so we can re-use the cache but it's + not horrible to have here either. + """ # If we can assume current_state_events.membership is up to date # then we can avoid a join, which is a Very Good Thing given how # frequently this function gets called. if self._current_state_events_membership_up_to_date: sql = """ - SELECT state_key FROM current_state_events - WHERE type = 'm.room.member' AND room_id = ? AND membership = ? + SELECT c.state_key FROM current_state_events as c + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ? + /* Sorted by lowest depth first */ + ORDER BY e.depth ASC; """ else: sql = """ - SELECT state_key FROM room_memberships as m + SELECT c.state_key FROM room_memberships as m + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) INNER JOIN current_state_events as c ON m.event_id = c.event_id AND m.room_id = c.room_id AND m.user_id = c.state_key WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? + /* Sorted by lowest depth first */ + ORDER BY e.depth ASC; """ txn.execute(sql, (room_id, Membership.JOIN)) return [r[0] for r in txn] + @cached() + def get_user_in_room_with_profile( + self, room_id: str, user_id: str + ) -> Dict[str, ProfileInfo]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_user_in_room_with_profile", list_name="user_ids" + ) + async def get_subset_users_in_room_with_profiles( + self, room_id: str, user_ids: Collection[str] + ) -> Dict[str, ProfileInfo]: + """Get a mapping from user ID to profile information for a list of users + in a given room. + + The profile information comes directly from this room's `m.room.member` + events, and so may be specific to this room rather than part of a user's + global profile. To avoid privacy leaks, the profile data should only be + revealed to users who are already in this room. + + Args: + room_id: The ID of the room to retrieve the users of. + user_ids: a list of users in the room to run the query for + + Returns: + A mapping from user ID to ProfileInfo. + """ + + def _get_subset_users_in_room_with_profiles( + txn: LoggingTransaction, + ) -> Dict[str, ProfileInfo]: + clause, ids = make_in_list_sql_clause( + self.database_engine, "c.state_key", user_ids + ) + + sql = """ + SELECT state_key, display_name, avatar_url FROM room_memberships as m + INNER JOIN current_state_events as c + ON m.event_id = c.event_id + AND m.room_id = c.room_id + AND m.user_id = c.state_key + WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s + """ % ( + clause, + ) + txn.execute(sql, (room_id, Membership.JOIN, *ids)) + + return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn} + + return await self.db_pool.runInteraction( + "get_subset_users_in_room_with_profiles", + _get_subset_users_in_room_with_profiles, + ) + @cached(max_entries=100000, iterable=True) async def get_users_in_room_with_profiles( self, room_id: str @@ -228,6 +303,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: A mapping from user ID to ProfileInfo. + + Preconditions: + - There is full state available for the room (it is not partial-stated). """ def _get_users_in_room_with_profiles( @@ -338,6 +416,15 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() + async def get_number_joined_users_in_room(self, room_id: str) -> int: + return await self.db_pool.simple_select_one_onecol( + table="current_state_events", + keyvalues={"room_id": room_id, "membership": Membership.JOIN}, + retcol="COUNT(*)", + desc="get_number_joined_users_in_room", + ) + + @cached() async def get_invited_rooms_for_local_user( self, user_id: str ) -> List[RoomsForUser]: @@ -416,6 +503,17 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id: str, membership_list: List[str], ) -> List[RoomsForUser]: + """Get all the rooms for this *local* user where the membership for this user + matches one in the membership list. + + Args: + user_id: The user ID. + membership_list: A list of synapse.api.constants.Membership + values which the user must be in. + + Returns: + The RoomsForUser that the user matches the membership types. + """ # Paranoia check. if not self.hs.is_mine_id(user_id): raise Exception( @@ -444,6 +542,44 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results + @cached(iterable=True) + async def get_local_users_in_room(self, room_id: str) -> List[str]: + """ + Retrieves a list of the current roommembers who are local to the server. + """ + return await self.db_pool.simple_select_onecol( + table="local_current_membership", + keyvalues={"room_id": room_id, "membership": Membership.JOIN}, + retcol="user_id", + desc="get_local_users_in_room", + ) + + async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool: + """ + Check whether a given local user is currently joined to the given room. + + Returns: + A boolean indicating whether the user is currently joined to the room + + Raises: + Exeption when called with a non-local user to this homeserver + """ + if not self.hs.is_mine_id(user_id): + raise Exception( + "Cannot call 'check_local_user_in_room' on " + "non-local user %s" % (user_id,), + ) + + ( + membership, + member_event_id, + ) = await self.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + + return membership == Membership.JOIN + async def get_local_current_membership_for_user_in_room( self, user_id: str, room_id: str ) -> Tuple[Optional[str], Optional[str]]: @@ -476,7 +612,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results_dict.get("membership"), results_dict.get("event_id") - @cached(max_entries=500000, iterable=True, prune_unread_entries=False) + @cached(max_entries=500000, iterable=True) async def get_rooms_for_user_with_stream_ordering( self, user_id: str ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]: @@ -647,25 +783,76 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) return frozenset(r.room_id for r in rooms) - @cached( - max_entries=500000, - cache_context=True, - iterable=True, - prune_unread_entries=False, + @cached(max_entries=10000) + async def does_pair_of_users_share_a_room( + self, user_id: str, other_user_id: str + ) -> bool: + raise NotImplementedError() + + @cachedList( + cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids" ) - async def get_users_who_share_room_with_user( - self, user_id: str, cache_context: _CacheContext + async def _do_users_share_a_room( + self, user_id: str, other_user_ids: Collection[str] + ) -> Mapping[str, Optional[bool]]: + """Return mapping from user ID to whether they share a room with the + given user. + + Note: `None` and `False` are equivalent and mean they don't share a + room. + """ + + def do_users_share_a_room_txn( + txn: LoggingTransaction, user_ids: Collection[str] + ) -> Dict[str, bool]: + clause, args = make_in_list_sql_clause( + self.database_engine, "state_key", user_ids + ) + + # This query works by fetching both the list of rooms for the target + # user and the set of other users, and then checking if there is any + # overlap. + sql = f""" + SELECT b.state_key + FROM ( + SELECT room_id FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ? + ) AS a + INNER JOIN ( + SELECT room_id, state_key FROM current_state_events + WHERE type = 'm.room.member' AND membership = 'join' AND {clause} + ) AS b using (room_id) + LIMIT 1 + """ + + txn.execute(sql, (user_id, *args)) + return {u: True for u, in txn} + + to_return = {} + for batch_user_ids in batch_iter(other_user_ids, 1000): + res = await self.db_pool.runInteraction( + "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids + ) + to_return.update(res) + + return to_return + + async def do_users_share_a_room( + self, user_id: str, other_user_ids: Collection[str] ) -> Set[str]: + """Return the set of users who share a room with the first users""" + + user_dict = await self._do_users_share_a_room(user_id, other_user_ids) + + return {u for u, share_room in user_dict.items() if share_room} + + async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]: """Returns the set of users who share a room with `user_id`""" - room_ids = await self.get_rooms_for_user( - user_id, on_invalidate=cache_context.invalidate - ) + room_ids = await self.get_rooms_for_user(user_id) user_who_share_room = set() for room_id in room_ids: - user_ids = await self.get_users_in_room( - room_id, on_invalidate=cache_context.invalidate - ) + user_ids = await self.get_users_in_room(room_id) user_who_share_room.update(user_ids) return user_who_share_room @@ -694,161 +881,92 @@ class RoomMemberWorkerStore(EventsWorkerStore): return shared_room_ids or frozenset() - async def get_joined_users_from_context( - self, event: EventBase, context: EventContext - ) -> Dict[str, ProfileInfo]: - state_group: Union[object, int] = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() + async def get_joined_user_ids_from_state( + self, room_id: str, state: StateMap[str] + ) -> Set[str]: + """ + For a given set of state IDs, get a set of user IDs in the room. - current_state_ids = await context.get_current_state_ids() - assert current_state_ids is not None - assert state_group is not None - return await self._get_joined_users_from_context( - event.room_id, state_group, current_state_ids, event=event, context=context - ) + This method checks the local event cache, before calling + `_get_user_ids_from_membership_event_ids` for any uncached events. + """ - async def get_joined_users_from_state( - self, room_id: str, state_entry: "_StateCacheEntry" - ) -> Dict[str, ProfileInfo]: - state_group: Union[object, int] = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() + with Measure(self._clock, "get_joined_user_ids_from_state"): + users_in_room = set() + member_event_ids = [ + e_id for key, e_id in state.items() if key[0] == EventTypes.Member + ] - assert state_group is not None - with Measure(self._clock, "get_joined_users_from_state"): - return await self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + # We check if we have any of the member event ids in the event cache + # before we ask the DB - @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) - async def _get_joined_users_from_context( - self, - room_id: str, - state_group: Union[object, int], - current_state_ids: StateMap[str], - cache_context: _CacheContext, - event: Optional[EventBase] = None, - context: Optional[Union[EventContext, "_StateCacheEntry"]] = None, - ) -> Dict[str, ProfileInfo]: - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - assert state_group is not None - - users_in_room = {} - member_event_ids = [ - e_id - for key, e_id in current_state_ids.items() - if key[0] == EventTypes.Member - ] - - if context is not None: - # If we have a context with a delta from a previous state group, - # check if we also have the result from the previous group in cache. - # If we do then we can reuse that result and simply update it with - # any membership changes in `delta_ids` - if context.prev_group and context.delta_ids: - prev_res = self._get_joined_users_from_context.cache.get_immediate( - (room_id, context.prev_group), None - ) - if prev_res and isinstance(prev_res, dict): - users_in_room = dict(prev_res) - member_event_ids = [ - e_id - for key, e_id in context.delta_ids.items() - if key[0] == EventTypes.Member - ] - for etype, state_key in context.delta_ids: - if etype == EventTypes.Member: - users_in_room.pop(state_key, None) - - # We check if we have any of the member event ids in the event cache - # before we ask the DB - - # We don't update the event cache hit ratio as it completely throws off - # the hit ratio counts. After all, we don't populate the cache if we - # miss it here - event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) - - missing_member_event_ids = [] - for event_id in member_event_ids: - ev_entry = event_map.get(event_id) - if ev_entry and not ev_entry.event.rejected_reason: - if ev_entry.event.membership == Membership.JOIN: - users_in_room[ev_entry.event.state_key] = ProfileInfo( - display_name=ev_entry.event.content.get("displayname", None), - avatar_url=ev_entry.event.content.get("avatar_url", None), - ) - else: - missing_member_event_ids.append(event_id) - - if missing_member_event_ids: - event_to_memberships = await self._get_joined_profiles_from_event_ids( - missing_member_event_ids + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_local_cache( + member_event_ids, update_metrics=False ) - users_in_room.update(row for row in event_to_memberships.values() if row) - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[event.state_key] = ProfileInfo( - display_name=event.content.get("displayname", None), - avatar_url=event.content.get("avatar_url", None), + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry and not ev_entry.event.rejected_reason: + if ev_entry.event.membership == Membership.JOIN: + users_in_room.add(ev_entry.event.state_key) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + event_to_memberships = ( + await self._get_user_ids_from_membership_event_ids( + missing_member_event_ids ) + ) + users_in_room.update( + user_id for user_id in event_to_memberships.values() if user_id + ) - return users_in_room + return users_in_room - @cached(max_entries=10000) - def _get_joined_profile_from_event_id( + @cached( + max_entries=10000, + # This name matches the old function that has been replaced - the cache name + # is kept here to maintain backwards compatibility. + name="_get_joined_profile_from_event_id", + ) + def _get_user_id_from_membership_event_id( self, event_id: str ) -> Optional[Tuple[str, ProfileInfo]]: raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", + cached_method_name="_get_user_id_from_membership_event_id", list_name="event_ids", ) - async def _get_joined_profiles_from_event_ids( + async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: + ) -> Dict[str, Optional[str]]: """For given set of member event_ids check if they point to a join - event and if so return the associated user and profile info. + event. Args: event_ids: The member event IDs to lookup Returns: - Map from event ID to `user_id` and ProfileInfo (or None if not join event). + Map from event ID to `user_id`, or None if event is not a join. """ rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, - retcols=("user_id", "display_name", "avatar_url", "event_id"), + retcols=("user_id", "event_id"), keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_joined_profiles_from_event_ids", + batch_size=1000, + desc="_get_user_ids_from_membership_event_ids", ) - return { - row["event_id"]: ( - row["user_id"], - ProfileInfo( - avatar_url=row["avatar_url"], display_name=row["display_name"] - ), - ) - for row in rows - } + return {row["event_id"]: row["user_id"] for row in rows} @cached(max_entries=10000) async def is_host_joined(self, room_id: str, host: str) -> bool: @@ -894,44 +1012,77 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: - """Get current hosts in room based on current state.""" + async def get_current_hosts_in_room(self, room_id: str) -> List[str]: + """ + Get current hosts in room based on current state. + + The heuristic of sorting by servers who have been in the room the + longest is good because they're most likely to have anything we ask + about. + + Returns: + Returns a list of servers sorted by longest in the room first. (aka. + sorted by join with the lowest depth first). + """ # First we check if we already have `get_users_in_room` in the cache, as # we can just calculate result from that users = self.get_users_in_room.cache.get_immediate( (room_id,), None, update_metrics=False ) - if users is not None: - return {get_domain_from_id(u) for u in users} - - if isinstance(self.database_engine, Sqlite3Engine): + if users is None and isinstance(self.database_engine, Sqlite3Engine): # If we're using SQLite then let's just always use # `get_users_in_room` rather than funky SQL. users = await self.get_users_in_room(room_id) - return {get_domain_from_id(u) for u in users} + + if users is not None: + # Because `users` is sorted from lowest -> highest depth, the list + # of domains will also be sorted that way. + domains: List[str] = [] + # We use a `Set` just for fast lookups + domain_set: Set[str] = set() + for u in users: + domain = get_domain_from_id(u) + if domain not in domain_set: + domain_set.add(domain) + domains.append(domain) + return domains # For PostgreSQL we can use a regex to pull out the domains from the # joined users in `current_state_events` via regex. - def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: + def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]: + # Returns a list of servers currently joined in the room sorted by + # longest in the room first (aka. with the lowest depth). The + # heuristic of sorting by servers who have been in the room the + # longest is good because they're most likely to have anything we + # ask about. sql = """ - SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$') - FROM current_state_events + SELECT + /* Match the domain part of the MXID */ + substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain + FROM current_state_events c + /* Get the depth of the event from the events table */ + INNER JOIN events AS e USING (event_id) WHERE - type = 'm.room.member' - AND membership = 'join' - AND room_id = ? + /* Find any join state events in the room */ + c.type = 'm.room.member' + AND c.membership = 'join' + AND c.room_id = ? + /* Group all state events from the same domain into their own buckets (groups) */ + GROUP BY server_domain + /* Sorted by lowest depth first */ + ORDER BY min(e.depth) ASC; """ txn.execute(sql, (room_id,)) - return {d for d, in txn} + return [d for d, in txn] return await self.db_pool.runInteraction( "get_current_hosts_in_room", get_current_hosts_in_room_txn ) async def get_joined_hosts( - self, room_id: str, state_entry: "_StateCacheEntry" + self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: state_group: Union[object, int] = state_entry.state_group if not state_group: @@ -944,7 +1095,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): assert state_group is not None with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( - room_id, state_group, state_entry=state_entry + room_id, state_group, state, state_entry=state_entry ) @cached(num_args=2, max_entries=10000, iterable=True) @@ -952,6 +1103,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): self, room_id: str, state_group: Union[object, int], + state: StateMap[str], state_entry: "_StateCacheEntry", ) -> FrozenSet[str]: # We don't use `state_group`, it's there so that we can cache based on @@ -1006,12 +1158,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): else: # The cache doesn't match the state group or prev state group, # so we calculate the result from first principles. - joined_users = await self.get_joined_users_from_state( - room_id, state_entry + joined_user_ids = await self.get_joined_user_ids_from_state( + room_id, state ) cache.hosts_to_joined_users = {} - for user_id in joined_users: + for user_id in joined_user_ids: host = intern_string(get_domain_from_id(user_id)) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) @@ -1090,6 +1242,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) + async def is_locally_forgotten_room(self, room_id: str) -> bool: + """Returns whether all local users have forgotten this room_id. + + Args: + room_id: The room ID to query. + + Returns: + Whether the room is forgotten. + """ + + sql = """ + SELECT count(*) > 0 FROM local_current_membership + INNER JOIN room_memberships USING (room_id, event_id) + WHERE + room_id = ? + AND forgotten = 0; + """ + + rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) + + # `count(*)` returns always an integer + # If any rows still exist it means someone has not forgotten this room yet + return not rows[0][0] + async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]: """Get all rooms that the user has ever been in. diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 78e0773b2a..f6e24b68d2 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -113,7 +113,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" - EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings" @@ -132,15 +131,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) - # we used to have a background update to turn the GIN index into a - # GIST one; we no longer do that (obviously) because we actually want - # a GIN index. However, it's possible that some people might still have - # the background update queued, so we register a handler to clear the - # background update. - self.db_pool.updates.register_noop_background_update( - self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME - ) - self.db_pool.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5e6efbd0fc..0b10af0e58 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -419,15 +419,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): # anything that was rejected should have the same state as its # predecessor. if context.rejected: - assert context.state_group == context.state_group_before_event + state_group = context.state_group_before_event + else: + state_group = context.state_group self.db_pool.simple_update_txn( txn, table="event_to_state_groups", keyvalues={"event_id": event.event_id}, - updatevalues={"state_group": context.state_group}, + updatevalues={"state_group": state_group}, ) + # the event may now be rejected where it was not before, or vice versa, + # in which case we need to update the rejected flags. + if bool(context.rejected) != (event.rejected_reason is not None): + self.mark_event_rejected_txn(txn, event.event_id, context.rejected) + self.db_pool.simple_delete_one_txn( txn, table="partial_state_events", @@ -435,11 +442,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # TODO(faster_joins): need to do something about workers here + # https://github.com/matrix-org/synapse/issues/12994 txn.call_after(self.is_partial_state_event.invalidate, (event.event_id,)) txn.call_after( self._get_state_group_for_event.prefill, (event.event_id,), - context.state_group, + state_group, ) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index b95dbef678..b4c652acf3 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from typing_extensions import Counter @@ -120,11 +120,6 @@ class StatsStore(StateDeltasStore): self.db_pool.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) - # we no longer need to perform clean-up, but we will give ourselves - # the potential to reintroduce it in the future – so documentation - # will still encourage the use of this no-op handler. - self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") - self.db_pool.updates.register_noop_background_update("populate_stats_prepare") async def _populate_stats_process_users( self, progress: JsonDict, batch_size: int @@ -243,6 +238,7 @@ class StatsStore(StateDeltasStore): * avatar * canonical_alias * guest_access + * room_type A is_federatable key can also be included with a boolean value. @@ -268,6 +264,7 @@ class StatsStore(StateDeltasStore): "avatar", "canonical_alias", "guest_access", + "room_type", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): @@ -300,6 +297,7 @@ class StatsStore(StateDeltasStore): keyvalues={id_col: id}, retcol="completed_delta_stream_id", allow_none=True, + desc="get_earliest_token_for_stats", ) async def bulk_update_stats_delta( @@ -576,7 +574,7 @@ class StatsStore(StateDeltasStore): state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined] - room_state = { + room_state: Dict[str, Union[None, bool, str]] = { "join_rules": None, "history_visibility": None, "encryption": None, @@ -585,6 +583,7 @@ class StatsStore(StateDeltasStore): "avatar": None, "canonical_alias": None, "is_federatable": True, + "room_type": None, } for event in state_event_map.values(): @@ -608,6 +607,9 @@ class StatsStore(StateDeltasStore): room_state["is_federatable"] = ( event.content.get(EventContentFields.FEDERATE, True) is True ) + room_type = event.content.get(EventContentFields.ROOM_TYPE) + if isinstance(room_type, str): + room_state["room_type"] = room_type await self.update_room_state(room_id, room_state) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8e88784d3c..a347430aa7 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -46,16 +46,19 @@ from typing import ( Set, Tuple, cast, + overload, ) import attr from frozendict import frozendict +from typing_extensions import Literal from twisted.internet import defer from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.opentracing import trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -795,6 +798,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) return RoomStreamToken(topo, stream_ordering) + @overload + def get_stream_id_for_event_txn( + self, + txn: LoggingTransaction, + event_id: str, + allow_none: Literal[False] = False, + ) -> int: + ... + + @overload + def get_stream_id_for_event_txn( + self, + txn: LoggingTransaction, + event_id: str, + allow_none: bool = False, + ) -> Optional[int]: + ... + def get_stream_id_for_event_txn( self, txn: LoggingTransaction, @@ -1002,8 +1023,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): } async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int - ) -> Tuple[int, List[EventBase]]: + self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False + ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: """Get all new events Returns all events with from_id < stream_ordering <= current_id. @@ -1012,19 +1033,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return + get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events), where `next_id` is the next value to - pass as `from_id` (it will either be the stream_ordering of the - last returned event, or, if fewer than `limit` events were found, - the `current_id`). + A tuple of (next_id, events, event_to_received_ts), where `next_id` + is the next value to pass as `from_id` (it will either be the + stream_ordering of the last returned event, or, if fewer than `limit` + events were found, the `current_id`). The `event_to_received_ts` is + a dictionary mapping event ID to the event `received_ts`. """ def get_all_new_events_stream_txn( txn: LoggingTransaction, - ) -> Tuple[int, List[str]]: + ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( - "SELECT e.stream_ordering, e.event_id" + "SELECT e.stream_ordering, e.event_id, e.received_ts" " FROM events AS e" " WHERE" " ? < e.stream_ordering AND e.stream_ordering <= ?" @@ -1039,15 +1062,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if len(rows) == limit: upper_bound = rows[-1][0] - return upper_bound, [row[1] for row in rows] + event_to_received_ts: Dict[str, Optional[int]] = { + row[1]: row[2] for row in rows + } + return upper_bound, event_to_received_ts - upper_bound, event_ids = await self.db_pool.runInteraction( + upper_bound, event_to_received_ts = await self.db_pool.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = await self.get_events_as_list(event_ids) + events = await self.get_events_as_list( + event_to_received_ts.keys(), + get_prev_content=get_prev_content, + ) - return upper_bound, events + return upper_bound, events, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: @@ -1318,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows, next_token + @trace async def paginate_room_events( self, room_id: str, diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index fa9eadaca7..a7fcc564a9 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -24,6 +24,7 @@ from synapse.storage.database import ( from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.types import MutableStateMap, StateMap +from synapse.util.caches import intern_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): txn.execute(sql % (where_clause,), args) for row in txn: typ, state_key, event_id = row - key = (typ, state_key) + key = (intern_string(typ), intern_string(state_key)) results[group][key] = event_id else: max_entries_returned = state_filter.max_entries_returned() diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..bb64543c1f 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): requests state from the cache, if False we need to query the DB for the missing state. """ - cache_entry = cache.get(group) + # If we are asked explicitly for a subset of keys, we only ask for those + # from the cache. This ensures that the `DictionaryCache` can make + # better decisions about what to cache and what to expire. + dict_keys = None + if not state_filter.has_wildcards(): + dict_keys = state_filter.concrete_types() + + cache_entry = cache.get(group, dict_keys=dict_keys) state_dict_ids = cache_entry.value if cache_entry.full or state_filter.is_full(): @@ -400,14 +407,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): room_id: str, prev_group: Optional[int], delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], + current_state_ids: Optional[StateMap[str]], ) -> int: """Store a new set of state, returning a newly assigned state group. + At least one of `current_state_ids` and `prev_group` must be provided. Whenever + `prev_group` is not None, `delta_ids` must also not be None. + Args: event_id: The event ID for which the state was calculated room_id - prev_group: A previous state group for the room, optional. + prev_group: A previous state group for the room. delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. @@ -418,10 +428,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): The state group ID """ - def _store_state_group_txn(txn: LoggingTransaction) -> int: - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") + if prev_group is None and current_state_ids is None: + raise Exception("current_state_ids and prev_group can't both be None") + + if prev_group is not None and delta_ids is None: + raise Exception("delta_ids is None when prev_group is not None") + + def insert_delta_group_txn( + txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str] + ) -> Optional[int]: + """Try and persist the new group as a delta. + + Requires that we have the state as a delta from a previous state group. + + Returns: + The state group if successfully created, or None if the state + needs to be persisted as a full state. + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + # if the chain of state group deltas is going too long, we fall back to + # persisting a complete state group. + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + return None state_group = self._state_group_seq_gen.get_next_id_txn(txn) @@ -431,51 +472,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): values={"id": state_group, "room_id": room_id, "event_id": event_id}, ) - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - assert delta_ids is not None - - self.db_pool.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) + self.db_pool.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_state", - keys=("state_group", "room_id", "type", "state_key", "event_id"), - values=[ - (state_group, room_id, key[0], key[1], state_id) - for key, state_id in delta_ids.items() - ], - ) - else: - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_state", - keys=("state_group", "room_id", "type", "state_key", "event_id"), - values=[ - (state_group, room_id, key[0], key[1], state_id) - for key, state_id in current_state_ids.items() - ], - ) + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + (state_group, room_id, key[0], key[1], state_id) + for key, state_id in delta_ids.items() + ], + ) + + return state_group + + def insert_full_state_txn( + txn: LoggingTransaction, current_state_ids: StateMap[str] + ) -> int: + """Persist the full state, returning the new state group.""" + state_group = self._state_group_seq_gen.get_next_id_txn(txn) + + self.db_pool.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + (state_group, room_id, key[0], key[1], state_id) + for key, state_id in current_state_ids.items() + ], + ) # Prefill the state group caches with this group. # It's fine to use the sequence like this as the state group map @@ -491,7 +526,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_members_cache.update, self._state_group_members_cache.sequence, key=state_group, - value=dict(current_member_state_ids), + value=current_member_state_ids, ) current_non_member_state_ids = { @@ -503,13 +538,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache.update, self._state_group_cache.sequence, key=state_group, - value=dict(current_non_member_state_ids), + value=current_non_member_state_ids, ) return state_group + if prev_group is not None: + state_group = await self.db_pool.runInteraction( + "store_state_group.insert_delta_group", + insert_delta_group_txn, + prev_group, + delta_ids, + ) + if state_group is not None: + return state_group + + # We're going to persist the state as a complete group rather than + # a delta, so first we need to ensure we have loaded the state map + # from the database. + if current_state_ids is None: + assert prev_group is not None + assert delta_ids is not None + groups = await self._get_state_for_groups([prev_group]) + current_state_ids = dict(groups[prev_group]) + current_state_ids.update(delta_ids) + return await self.db_pool.runInteraction( - "store_state_group", _store_state_group_txn + "store_state_group.insert_full_state", + insert_full_state_txn, + current_state_ids, ) async def purge_unreferenced_state_groups(