From c3f2f0f063e5d97e1ae396b85330ea249499f077 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jan 2024 12:29:42 +0000 Subject: Faster partial join to room with complex auth graph (#7) Instead of persisting outliers in a bunch of batches, let's just do them all at once. This is fine because all `_auth_and_persist_outliers_inner` is doing is checking the auth rules for each event, which requires the events to be topologically sorted by the auth graph. --- synapse/handlers/federation_event.py | 79 ++++++++++++++---------------------- 1 file changed, 30 insertions(+), 49 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 882be905db..398f19eec0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -94,7 +94,7 @@ from synapse.types import ( ) from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter, partition, sorted_topologically_batched +from synapse.util.iterutils import batch_iter, partition, sorted_topologically from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -1678,57 +1678,36 @@ class FederationEventHandler: # We need to persist an event's auth events before the event. auth_graph = { - ev: [event_map[e_id] for e_id in ev.auth_event_ids() if e_id in event_map] + ev.event_id: [e_id for e_id in ev.auth_event_ids() if e_id in event_map] for ev in event_map.values() } - for roots in sorted_topologically_batched(event_map.values(), auth_graph): - if not roots: - # if *none* of the remaining events are ready, that means - # we have a loop. This either means a bug in our logic, or that - # somebody has managed to create a loop (which requires finding a - # hash collision in room v2 and later). - logger.warning( - "Loop found in auth events while fetching missing state/auth " - "events: %s", - shortstr(event_map.keys()), - ) - return - - logger.info( - "Persisting %i of %i remaining outliers: %s", - len(roots), - len(event_map), - shortstr(e.event_id for e in roots), - ) - - await self._auth_and_persist_outliers_inner(room_id, roots) - - async def _auth_and_persist_outliers_inner( - self, room_id: str, fetched_events: Collection[EventBase] - ) -> None: - """Helper for _auth_and_persist_outliers - - Persists a batch of events where we have (theoretically) already persisted all - of their auth events. - - Marks the events as outliers, auths them, persists them to the database, and, - where appropriate (eg, an invite), awakes the notifier. + sorted_auth_event_ids = sorted_topologically(event_map.keys(), auth_graph) + sorted_auth_events = [event_map[e_id] for e_id in sorted_auth_event_ids] + logger.info( + "Persisting %i remaining outliers: %s", + len(sorted_auth_events), + shortstr(e.event_id for e in sorted_auth_events), + ) - Params: - origin: where the events came from - room_id: the room that the events are meant to be in (though this has - not yet been checked) - fetched_events: the events to persist - """ # get all the auth events for all the events in this batch. By now, they should # have been persisted. - auth_events = { - aid for event in fetched_events for aid in event.auth_event_ids() + auth_event_ids = { + aid for event in sorted_auth_events for aid in event.auth_event_ids() + } + auth_map = { + ev.event_id: ev + for ev in sorted_auth_events + if ev.event_id in auth_event_ids } - persisted_events = await self._store.get_events( - auth_events, - allow_rejected=True, - ) + + missing_events = auth_event_ids.difference(auth_map) + if missing_events: + persisted_events = await self._store.get_events( + missing_events, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) + auth_map.update(persisted_events) events_and_contexts_to_persist: List[Tuple[EventBase, EventContext]] = [] @@ -1736,7 +1715,7 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): auth = [] for auth_event_id in event.auth_event_ids(): - ae = persisted_events.get(auth_event_id) + ae = auth_map.get(auth_event_id) if not ae: # the fact we can't find the auth event doesn't mean it doesn't # exist, which means it is premature to reject `event`. Instead we @@ -1755,7 +1734,9 @@ class FederationEventHandler: context = EventContext.for_outlier(self._storage_controllers) try: validate_event_for_room_version(event) - await check_state_independent_auth_rules(self._store, event) + await check_state_independent_auth_rules( + self._store, event, batched_auth_events=auth_map + ) check_state_dependent_auth_rules(event, auth) except AuthError as e: logger.warning("Rejecting %r because %s", event, e) @@ -1772,7 +1753,7 @@ class FederationEventHandler: events_and_contexts_to_persist.append((event, context)) - for event in fetched_events: + for event in sorted_auth_events: await prep(event) await self.persist_events_and_notify( -- cgit 1.5.1 From 4c67f0391ba6001aea37642935466bbbb145da7a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jan 2024 13:55:16 +0000 Subject: Split up deleting devices into batches (#16766) Otherwise for users with large numbers of devices this can cause a lot of woe. --- changelog.d/16766.misc | 1 + synapse/storage/databases/main/devices.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 changelog.d/16766.misc (limited to 'synapse') diff --git a/changelog.d/16766.misc b/changelog.d/16766.misc new file mode 100644 index 0000000000..ded77a11c4 --- /dev/null +++ b/changelog.d/16766.misc @@ -0,0 +1 @@ +Split up deleting devices into batches. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 66fabf4f76..ae914298fb 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1794,7 +1794,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_ids: The IDs of the devices to delete """ - def _delete_devices_txn(txn: LoggingTransaction) -> None: + def _delete_devices_txn(txn: LoggingTransaction, device_ids: List[str]) -> None: self.db_pool.simple_delete_many_txn( txn, table="devices", @@ -1811,7 +1811,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): keyvalues={"user_id": user_id}, ) - await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) + for batch in batch_iter(device_ids, 100): + await self.db_pool.runInteraction( + "delete_devices", _delete_devices_txn, batch + ) + for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) -- cgit 1.5.1 From 578c5c736e4ca2479c8b72b5e9ac20cd7acde0e8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jan 2024 14:31:35 +0000 Subject: Reduce amount of state pulled out when querying federation hierachy (#16785) There are two changes here: 1. Only pull out the required state when handling the request. 2. Change the get filtered state return type to check that we're only querying state that was requested --------- Co-authored-by: reivilibre --- changelog.d/16785.misc | 1 + synapse/handlers/room_summary.py | 12 ++++++++- synapse/storage/databases/main/state.py | 48 +++++++++++++++++++++++++++++++-- synapse/types/state.py | 24 +++++++++++++++++ 4 files changed, 82 insertions(+), 3 deletions(-) create mode 100644 changelog.d/16785.misc (limited to 'synapse') diff --git a/changelog.d/16785.misc b/changelog.d/16785.misc new file mode 100644 index 0000000000..4de185c5dd --- /dev/null +++ b/changelog.d/16785.misc @@ -0,0 +1 @@ +Reduce amount of state pulled out when querying federation hierachy. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index a534f5f280..78bcac1429 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -44,6 +44,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection +from synapse.types.state import StateFilter from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -546,7 +547,16 @@ class RoomSummaryHandler: Returns: True if the room is accessible to the requesting user or server. """ - state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) + event_types = [ + (EventTypes.JoinRules, ""), + (EventTypes.RoomHistoryVisibility, ""), + ] + if requester: + event_types.append((EventTypes.Member, requester)) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, state_filter=StateFilter.from_types(event_types) + ) # If there's no state for the room, it isn't known. if not state_ids: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4700e74ad2..06c44bb563 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -30,7 +30,10 @@ from typing import ( Optional, Set, Tuple, + TypeVar, + Union, cast, + overload, ) import attr @@ -52,7 +55,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types import JsonDict, JsonMapping, StateKey, StateMap from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -64,6 +67,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_T = TypeVar("_T") + MAX_STATE_DELTA_HOPS = 100 @@ -349,7 +354,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, ) -> StateMap[str]: - results = {} + results = StateMapWrapper(state_filter=state_filter or StateFilter.all()) + sql = """ SELECT type, state_key, event_id FROM current_state_events WHERE room_id = ? @@ -726,3 +732,41 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): hs: "HomeServer", ): super().__init__(database, db_conn, hs) + + +@attr.s(auto_attribs=True, slots=True) +class StateMapWrapper(Dict[StateKey, str]): + """A wrapper around a StateMap[str] to ensure that we only query for items + that were not filtered out. + + This is to help prevent bugs where we filter out state but other bits of the + code expect the state to be there. + """ + + state_filter: StateFilter + + def __getitem__(self, key: StateKey) -> str: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + return super().__getitem__(key) + + @overload + def get(self, key: Tuple[str, str]) -> Optional[str]: + ... + + @overload + def get(self, key: Tuple[str, str], default: Union[str, _T]) -> Union[str, _T]: + ... + + def get( + self, key: StateKey, default: Union[str, _T, None] = None + ) -> Union[str, _T, None]: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + return super().get(key, default) + + def __contains__(self, key: Any) -> bool: + if key not in self.state_filter: + raise Exception("State map was filtered and doesn't include: %s", key) + + return super().__contains__(key) diff --git a/synapse/types/state.py b/synapse/types/state.py index 5ca3c94bce..53662372af 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -20,6 +20,7 @@ import logging from typing import ( TYPE_CHECKING, + Any, Callable, Collection, Dict, @@ -584,6 +585,29 @@ class StateFilter: # local users only return False + def __contains__(self, key: Any) -> bool: + if not isinstance(key, tuple) or len(key) != 2: + raise TypeError( + f"'in StateFilter' requires (str, str) as left operand, not {type(key).__name__}" + ) + + typ, state_key = key + + if not isinstance(typ, str) or not isinstance(state_key, str): + raise TypeError( + f"'in StateFilter' requires (str, str) as left operand, not ({type(typ).__name__}, {type(state_key).__name__})" + ) + + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + return True + + elif self.include_others: + return True + + return False + _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter( -- cgit 1.5.1 From 0a96fa52a214d0cc707dbdf9693118ceddbfc150 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jan 2024 14:42:13 +0000 Subject: Pull less state out if we fail to backfill (#16788) Sometimes we fail to fetch events during backfill due to missing state, and we often end up querying the same bad events periodically (as people backpaginate). In such cases its likely we will continue to fail to get the state, and therefore we should try *before* loading the state that we have from the DB (as otherwise it's wasted DB and memory). --------- Co-authored-by: reivilibre --- changelog.d/16788.misc | 1 + synapse/handlers/federation_event.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 changelog.d/16788.misc (limited to 'synapse') diff --git a/changelog.d/16788.misc b/changelog.d/16788.misc new file mode 100644 index 0000000000..e58a5a7a32 --- /dev/null +++ b/changelog.d/16788.misc @@ -0,0 +1 @@ +Pull less state out of the DB when we retry fetching old events during backfill. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 398f19eec0..12837429b9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1141,16 +1141,8 @@ class FederationEventHandler: partial_state_flags = await self._store.get_partial_state_events(seen) partial_state = any(partial_state_flags.values()) - # Get the state of the events we know about - ours = await self._state_storage_controller.get_state_groups_ids( - room_id, seen, await_full_state=False - ) - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours + state_maps: List[StateMap[str]] = [] # Ask the remote server for the states we don't # know about @@ -1169,6 +1161,17 @@ class FederationEventHandler: state_maps.append(remote_state_map) + # Get the state of the events we know about. We do this *after* + # trying to fetch missing state over federation as that might fail + # and then we can skip loading the local state. + ours = await self._state_storage_controller.get_state_groups_ids( + room_id, seen, await_full_state=False + ) + state_maps.extend(ours.values()) + + # we don't need this any more, let's delete it. + del ours + room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, -- cgit 1.5.1 From cbe8a80d108068c585f76bc18e14c61371644f84 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jan 2024 15:11:59 +0000 Subject: Faster load recents for sync (#16783) This hopefully reduces the amount of state we need to keep in memory --- changelog.d/16783.misc | 1 + synapse/handlers/sync.py | 14 ++++++++------ synapse/storage/databases/main/state.py | 17 ++++++++++++++++- 3 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 changelog.d/16783.misc (limited to 'synapse') diff --git a/changelog.d/16783.misc b/changelog.d/16783.misc new file mode 100644 index 0000000000..9d3b96ffc6 --- /dev/null +++ b/changelog.d/16783.misc @@ -0,0 +1 @@ +Faster load recents for sync by reducing amount of state pulled out. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0385c04bc2..2e10035772 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -583,10 +583,11 @@ class SyncHandler: # `recents`, so partial state is only a problem when a membership # event turns up in `recents` but has not made it into the current # state. - current_state_ids_map = ( - await self.store.get_partial_current_state_ids(room_id) + current_state_ids = ( + await self.store.check_if_events_in_current_state( + {e.event_id for e in recents if e.is_state()} + ) ) - current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( self._storage_controllers, @@ -667,10 +668,11 @@ class SyncHandler: # `loaded_recents`, so partial state is only a problem when a # membership event turns up in `loaded_recents` but has not made it # into the current state. - current_state_ids_map = ( - await self.store.get_partial_current_state_ids(room_id) + current_state_ids = ( + await self.store.check_if_events_in_current_state( + {e.event_id for e in loaded_recents if e.is_state()} + ) ) - current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( self._storage_controllers, diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 06c44bb563..8006046453 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -24,6 +24,7 @@ from typing import ( Any, Collection, Dict, + FrozenSet, Iterable, List, Mapping, @@ -55,7 +56,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.types import JsonDict, JsonMapping, StateKey, StateMap +from synapse.types import JsonDict, JsonMapping, StateKey, StateMap, StrCollection from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList @@ -323,6 +324,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_partial_current_state_ids", _get_current_state_ids_txn ) + async def check_if_events_in_current_state( + self, event_ids: StrCollection + ) -> FrozenSet[str]: + """Checks and returns which of the given events is part of the current state.""" + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="event_id", + iterable=event_ids, + retcols=("event_id",), + desc="check_if_events_in_current_state", + ) + + return frozenset(event_id for event_id, in rows) + # FIXME: how should this be cached? @cancellable async def get_partial_filtered_current_state_ids( -- cgit 1.5.1 From a986f86c827879f26abd574da6c7d12ac9578dd8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jan 2024 17:16:49 +0000 Subject: Correctly handle OIDC config with no `client_secret` set (#16806) In previous versions of authlib using `client_secret_basic` without a `client_secret` would result in an invalid auth header. Since authlib 1.3 it throws an exception. The configuration may be accepted in by very lax servers, so we don't want to deny it outright. Instead, let's default the `client_auth_method` to `none`, which does the right thing. If the config specifies `client_auth_method` and no `client_secret` then that is going to be bogus and we should reject it --- changelog.d/16806.misc | 1 + synapse/config/oidc.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 changelog.d/16806.misc (limited to 'synapse') diff --git a/changelog.d/16806.misc b/changelog.d/16806.misc new file mode 100644 index 0000000000..623338268b --- /dev/null +++ b/changelog.d/16806.misc @@ -0,0 +1 @@ +Reject OIDC config when `client_secret` isn't specified, but the auth method requires one. diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 07ca16c94c..8f9cdbddbb 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -299,6 +299,19 @@ def _parse_oidc_config_dict( config_path + ("client_secret",), ) + # If no client secret is specified then the auth method must be None + client_auth_method = oidc_config.get("client_auth_method") + if client_secret is None and client_secret_jwt_key is None: + if client_auth_method is None: + client_auth_method = "none" + elif client_auth_method != "none": + raise ConfigError( + "No 'client_secret' is set in OIDC config, and 'client_auth_method' is not set to 'none'" + ) + + if client_auth_method is None: + client_auth_method = "client_secret_basic" + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), @@ -309,7 +322,7 @@ def _parse_oidc_config_dict( client_id=oidc_config["client_id"], client_secret=client_secret, client_secret_jwt_key=client_secret_jwt_key, - client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + client_auth_method=client_auth_method, pkce_method=oidc_config.get("pkce_method", "auto"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), -- cgit 1.5.1 From b11f7b5122061d4908b3328689486bc16dc58445 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Jan 2024 11:52:13 +0000 Subject: Improve DB performance of calculating badge counts for push. (#16756) The crux of the change is to try and make the queries simpler and pull out fewer rows. Before, there were quite a few joins against subqueries, which caused postgres to pull out more rows than necessary. Instead, let's simplify the query and do some of the filtering out in Python instead, letting Postgres do better optimizations now that it doesn't have to deal with joins against subqueries. Review note: this is a complete rewrite of the function, so not sure how useful the diff is. --------- Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/16756.misc | 1 + synapse/push/push_tools.py | 8 +- .../storage/databases/main/event_push_actions.py | 253 ++++++++++++--------- 3 files changed, 148 insertions(+), 114 deletions(-) create mode 100644 changelog.d/16756.misc (limited to 'synapse') diff --git a/changelog.d/16756.misc b/changelog.d/16756.misc new file mode 100644 index 0000000000..200e18fb7b --- /dev/null +++ b/changelog.d/16756.misc @@ -0,0 +1 @@ +Improve DB performance of calculating badge counts for push. diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 03ce0b4dc6..cce9583fa7 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -28,17 +28,11 @@ from synapse.storage.databases.main import DataStore async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: invites = await store.get_invited_rooms_for_local_user(user_id) - joins = await store.get_rooms_for_user(user_id) badge = len(invites) room_to_count = await store.get_unread_counts_by_room_for_user(user_id) - for room_id, notify_count in room_to_count.items(): - # room_to_count may include rooms which the user has left, - # ignore those. - if room_id not in joins: - continue - + for _room_id, notify_count in room_to_count.items(): if notify_count == 0: continue diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 650b8c8135..6d4e2942ea 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -357,10 +357,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas This function is intentionally not cached because it is called to calculate the unread badge for push notifications and thus the result is expected to change. - Note that this function assumes the user is a member of the room. Because - summary rows are not removed when a user leaves a room, the caller must - filter out those results from the result. - Returns: A map of room ID to notification counts for the given user. """ @@ -373,127 +369,170 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas def _get_unread_counts_by_room_for_user_txn( self, txn: LoggingTransaction, user_id: str ) -> Dict[str, int]: - receipt_types_clause, args = make_in_list_sql_clause( + # To get the badge count of all rooms we need to make three queries: + # 1. Fetch all counts from `event_push_summary`, discarding any stale + # rooms. + # 2. Fetch all notifications from `event_push_actions` that haven't + # been rotated yet. + # 3. Fetch all notifications from `event_push_actions` for the stale + # rooms. + # + # The "stale room" scenario generally happens when there is a new read + # receipt that hasn't yet been processed to update the + # `event_push_summary` table. When that happens we ignore the + # `event_push_summary` table for that room and calculate the count + # manually from `event_push_actions`. + + # We need to only take into account read receipts of these types. + receipt_types_clause, receipt_types_args = make_in_list_sql_clause( self.database_engine, "receipt_type", (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) - args.extend([user_id, user_id]) - - receipts_cte = f""" - WITH all_receipts AS ( - SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering - FROM receipts_linearized - LEFT JOIN events USING (room_id, event_id) - WHERE - {receipt_types_clause} - AND user_id = ? - GROUP BY room_id, thread_id - ) - """ - - receipts_joins = """ - LEFT JOIN ( - SELECT room_id, thread_id, - max_receipt_stream_ordering AS threaded_receipt_stream_ordering - FROM all_receipts - WHERE thread_id IS NOT NULL - ) AS threaded_receipts USING (room_id, thread_id) - LEFT JOIN ( - SELECT room_id, thread_id, - max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering - FROM all_receipts - WHERE thread_id IS NULL - ) AS unthreaded_receipts USING (room_id) - """ - - # First get summary counts by room / thread for the user. We use the max receipt - # stream ordering of both threaded & unthreaded receipts to compare against the - # summary table. - # - # PostgreSQL and SQLite differ in comparing scalar numerics. - if isinstance(self.database_engine, PostgresEngine): - # GREATEST ignores NULLs. - max_clause = """GREATEST( - threaded_receipt_stream_ordering, - unthreaded_receipt_stream_ordering - )""" - else: - # MAX returns NULL if any are NULL, so COALESCE to 0 first. - max_clause = """MAX( - COALESCE(threaded_receipt_stream_ordering, 0), - COALESCE(unthreaded_receipt_stream_ordering, 0) - )""" + # Step 1, fetch all counts from `event_push_summary` for the user. This + # is slightly convoluted as we also need to pull out the stream ordering + # of the most recent receipt of the user in the room (either a thread + # aware receipt or thread unaware receipt) in order to determine + # whether the row in `event_push_summary` is stale. Hence the outer + # GROUP BY and odd join condition against `receipts_linearized`. sql = f""" - {receipts_cte} - SELECT eps.room_id, eps.thread_id, notif_count - FROM event_push_summary AS eps - {receipts_joins} - WHERE user_id = ? - AND notif_count != 0 - AND ( - (last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause}) - OR last_receipt_stream_ordering = {max_clause} + SELECT room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering, + MAX(receipt_stream_ordering) + FROM ( + SELECT e.room_id, notif_count, e.stream_ordering, e.thread_id, last_receipt_stream_ordering, + ev.stream_ordering AS receipt_stream_ordering + FROM event_push_summary AS e + INNER JOIN local_current_membership USING (user_id, room_id) + LEFT JOIN receipts_linearized AS r ON ( + e.user_id = r.user_id + AND e.room_id = r.room_id + AND (e.thread_id = r.thread_id OR r.thread_id IS NULL) + AND {receipt_types_clause} ) + LEFT JOIN events AS ev ON (r.event_id = ev.event_id) + WHERE e.user_id = ? and notif_count > 0 + ) AS es + GROUP BY room_id, notif_count, stream_ordering, thread_id, last_receipt_stream_ordering """ - txn.execute(sql, args) - - seen_thread_ids = set() - room_to_count: Dict[str, int] = defaultdict(int) - for room_id, thread_id, notif_count in txn: - room_to_count[room_id] += notif_count - seen_thread_ids.add(thread_id) + txn.execute( + sql, + receipt_types_args + + [ + user_id, + ], + ) - # Now get any event push actions that haven't been rotated using the same OR - # join and filter by receipt and event push summary rotated up to stream ordering. - sql = f""" - {receipts_cte} - SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count - FROM event_push_actions AS epa - {receipts_joins} - WHERE user_id = ? - AND epa.notif = 1 - AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering) - AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) - AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) - GROUP BY epa.room_id, epa.thread_id - """ - txn.execute(sql, args) + room_to_count: Dict[str, int] = defaultdict(int) + stale_room_ids = set() + for row in txn: + room_id = row[0] + notif_count = row[1] + stream_ordering = row[2] + _thread_id = row[3] + last_receipt_stream_ordering = row[4] + receipt_stream_ordering = row[5] + + if last_receipt_stream_ordering is None: + if receipt_stream_ordering is None: + room_to_count[room_id] += notif_count + elif stream_ordering > receipt_stream_ordering: + room_to_count[room_id] += notif_count + else: + # The latest read receipt from the user is after all the rows for + # this room in `event_push_summary`. We ignore them, and + # calculate the count from `event_push_actions` in step 3. + pass + elif last_receipt_stream_ordering == receipt_stream_ordering: + room_to_count[room_id] += notif_count + else: + # The row is stale if `last_receipt_stream_ordering` is set and + # *doesn't* match the latest receipt from the user. + stale_room_ids.add(room_id) - for room_id, thread_id, notif_count in txn: - # Note: only count push actions we have valid summaries for with up to date receipt. - if thread_id not in seen_thread_ids: - continue - room_to_count[room_id] += notif_count + # Discard any stale rooms from `room_to_count`, as we will recalculate + # them in step 3. + for room_id in stale_room_ids: + room_to_count.pop(room_id, None) - thread_id_clause, thread_ids_args = make_in_list_sql_clause( - self.database_engine, "epa.thread_id", seen_thread_ids + # Step 2, basically the same query, except against `event_push_actions` + # and only fetching rows inserted since the last rotation. + rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", ) - # Finally re-check event_push_actions for any rooms not in the summary, ignoring - # the rotated up-to position. This handles the case where a read receipt has arrived - # but not been rotated meaning the summary table is out of date, so we go back to - # the push actions table. sql = f""" - {receipts_cte} - SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count - FROM event_push_actions AS epa - {receipts_joins} - WHERE user_id = ? - AND NOT {thread_id_clause} - AND epa.notif = 1 - AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering) - AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering) - GROUP BY epa.room_id + SELECT room_id, thread_id + FROM ( + SELECT e.room_id, e.stream_ordering, e.thread_id, + ev.stream_ordering AS receipt_stream_ordering + FROM event_push_actions AS e + INNER JOIN local_current_membership USING (user_id, room_id) + LEFT JOIN receipts_linearized AS r ON ( + e.user_id = r.user_id + AND e.room_id = r.room_id + AND (e.thread_id = r.thread_id OR r.thread_id IS NULL) + AND {receipt_types_clause} + ) + LEFT JOIN events AS ev ON (r.event_id = ev.event_id) + WHERE e.user_id = ? and notif > 0 + AND e.stream_ordering > ? + ) AS es + GROUP BY room_id, stream_ordering, thread_id + HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0) """ - args.extend(thread_ids_args) - txn.execute(sql, args) + txn.execute( + sql, + receipt_types_args + [user_id, rotated_upto_stream_ordering], + ) + for room_id, _thread_id in txn: + # Again, we ignore any stale rooms. + if room_id not in stale_room_ids: + # For event push actions it is one notification per row. + room_to_count[room_id] += 1 + + # Step 3, if we have stale rooms then we need to recalculate the counts + # from `event_push_actions`. Again, this is basically the same query as + # above except without a lower bound on stream ordering and only against + # a specific set of rooms. + if stale_room_ids: + room_id_clause, room_id_args = make_in_list_sql_clause( + self.database_engine, + "e.room_id", + stale_room_ids, + ) - for room_id, notif_count in txn: - room_to_count[room_id] += notif_count + sql = f""" + SELECT room_id, thread_id + FROM ( + SELECT e.room_id, e.stream_ordering, e.thread_id, + ev.stream_ordering AS receipt_stream_ordering + FROM event_push_actions AS e + INNER JOIN local_current_membership USING (user_id, room_id) + LEFT JOIN receipts_linearized AS r ON ( + e.user_id = r.user_id + AND e.room_id = r.room_id + AND (e.thread_id = r.thread_id OR r.thread_id IS NULL) + AND {receipt_types_clause} + ) + LEFT JOIN events AS ev ON (r.event_id = ev.event_id) + WHERE e.user_id = ? and notif > 0 + AND {room_id_clause} + ) AS es + GROUP BY room_id, stream_ordering, thread_id + HAVING stream_ordering > COALESCE(MAX(receipt_stream_ordering), 0) + """ + txn.execute( + sql, + receipt_types_args + [user_id] + room_id_args, + ) + for room_id, _ in txn: + room_to_count[room_id] += 1 return room_to_count -- cgit 1.5.1 From c43f751013489c7a037312455d4ee2793065ed6c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Jan 2024 13:37:57 +0000 Subject: Optimize query for fetching to-device messages in `/sync` (#16805) The current query supports passing in a list of users, which generates a query using `user_id = ANY(..)`. This is generates a less efficient query plan that is notably slower than a simple `user_id = ?` condition. Note: The new function is mostly a copy and paste and then a simplification of the existing function. --- changelog.d/16805.misc | 1 + synapse/storage/databases/main/deviceinbox.py | 149 +++++++++++++------------- 2 files changed, 73 insertions(+), 77 deletions(-) create mode 100644 changelog.d/16805.misc (limited to 'synapse') diff --git a/changelog.d/16805.misc b/changelog.d/16805.misc new file mode 100644 index 0000000000..0b54ab0f74 --- /dev/null +++ b/changelog.d/16805.misc @@ -0,0 +1 @@ +Optimize query for fetching to-device messages in `/sync`. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 40477b9da0..fa47b471e8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore): * The last-processed stream ID. Subsequent calls of this function with the same device should pass this value as 'from_stream_id'. """ - ( - user_id_device_id_to_messages, - last_processed_stream_id, - ) = await self._get_device_messages( - user_ids=[user_id], - device_id=device_id, - from_stream_id=from_stream_id, - to_stream_id=to_stream_id, - limit=limit, - ) - - if not user_id_device_id_to_messages: + if not self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): # There were no messages! return [], to_stream_id - # Extract the messages, no need to return the user and device ID again - to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), []) + def get_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + sql = """ + SELECT stream_id, message_json FROM device_inbox + WHERE user_id = ? AND device_id = ? + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit)) + + # Create and fill a dictionary of (user ID, device ID) -> list of messages + # intended for each device. + last_processed_stream_pos = to_stream_id + to_device_messages: List[JsonDict] = [] + rowcount = 0 + for row in txn: + rowcount += 1 + + last_processed_stream_pos = row[0] + message_dict = db_to_json(row[1]) + + # Store the device details + to_device_messages.append(message_dict) - return to_device_messages, last_processed_stream_id + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + + if rowcount == limit: + # We ended up bumping up against the message limit. There may be more messages + # to retrieve. Return what we have, as well as the last stream position that + # was processed. + # + # The caller is expected to set this as the lower (exclusive) bound + # for the next query of this device. + return to_device_messages, last_processed_stream_pos + + # The limit was not reached, thus we know that recipient_device_to_messages + # contains all to-device messages for the given device and stream id range. + # + # We return to_stream_id, which the caller should then provide as the lower + # (exclusive) bound on the next query of this device. + return to_device_messages, to_stream_id + + return await self.db_pool.runInteraction( + "get_messages_for_device", get_device_messages_txn + ) async def _get_device_messages( self, user_ids: Collection[str], from_stream_id: int, to_stream_id: int, - device_id: Optional[str] = None, - limit: Optional[int] = None, ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: """ Retrieve pending to-device messages for a collection of user devices. @@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): user_ids: The user IDs to filter device messages by. from_stream_id: The lower boundary of stream id to filter with (exclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive). - device_id: A device ID to query to-device messages for. If not provided, to-device - messages from all device IDs for the given user IDs will be queried. May not be - provided if `user_ids` contains more than one entry. - limit: The maximum number of to-device messages to return. Can only be used when - passing a single user ID / device ID tuple. + Returns: A tuple containing: @@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): logger.warning("No users provided upon querying for device IDs") return {}, to_stream_id - # Prevent a query for one user's device also retrieving another user's device with - # the same device ID (device IDs are not unique across users). - if len(user_ids) > 1 and device_id is not None: - raise AssertionError( - "Programming error: 'device_id' cannot be supplied to " - "_get_device_messages when >1 user_id has been provided" - ) - - # A limit can only be applied when querying for a single user ID / device ID tuple. - # See the docstring of this function for more details. - if limit is not None and device_id is None: - raise AssertionError( - "Programming error: _get_device_messages was passed 'limit' " - "without a specific user_id/device_id" - ) - user_ids_to_query: Set[str] = set() - device_ids_to_query: Set[str] = set() - - # Note that a device ID could be an empty str - if device_id is not None: - # If a device ID was passed, use it to filter results. - # Otherwise, device IDs will be derived from the given collection of user IDs. - device_ids_to_query.add(device_id) # Determine which users have devices with pending messages for user_id in user_ids: @@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore): # hidden devices should not receive to-device messages. # Note that this is more efficient than just dropping `device_id` from the query, # since device_inbox has an index on `(user_id, device_id, stream_id)` - if not device_ids_to_query: - user_device_dicts = cast( - List[Tuple[str]], - self.db_pool.simple_select_many_txn( - txn, - table="devices", - column="user_id", - iterable=user_ids_to_query, - keyvalues={"hidden": False}, - retcols=("device_id",), - ), - ) - device_ids_to_query.update({row[0] for row in user_device_dicts}) + user_device_dicts = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="devices", + column="user_id", + iterable=user_ids_to_query, + keyvalues={"hidden": False}, + retcols=("device_id",), + ), + ) + + device_ids_to_query = {row[0] for row in user_device_dicts} if not device_ids_to_query: # We've ended up with no devices to query. @@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore): to_stream_id, ) - # If a limit was provided, limit the data retrieved from the database - if limit is not None: - sql += "LIMIT ?" - sql_args += (limit,) - txn.execute(sql, sql_args) # Create and fill a dictionary of (user ID, device ID) -> list of messages # intended for each device. - last_processed_stream_pos = to_stream_id recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} rowcount = 0 for row in txn: rowcount += 1 - last_processed_stream_pos = row[0] recipient_user_id = row[1] recipient_device_id = row[2] message_dict = db_to_json(row[3]) @@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore): message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), ) - if limit is not None and rowcount == limit: - # We ended up bumping up against the message limit. There may be more messages - # to retrieve. Return what we have, as well as the last stream position that - # was processed. - # - # The caller is expected to set this as the lower (exclusive) bound - # for the next query of this device. - return recipient_device_to_messages, last_processed_stream_pos - - # The limit was not reached, thus we know that recipient_device_to_messages - # contains all to-device messages for the given device and stream id range. - # # We return to_stream_id, which the caller should then provide as the lower # (exclusive) bound on the next query of this device. return recipient_device_to_messages, to_stream_id -- cgit 1.5.1