From c3516e9decc355b75a297d72a13b98a43d312e66 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 16 Aug 2022 12:16:56 +0000 Subject: Faster room joins: make `/joined_members` block whilst the room is partial stated. (#13514) --- synapse/storage/controllers/state.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'synapse/storage/controllers/state.py') diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0d480f1014..0c78eb735e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -30,6 +30,7 @@ from typing import ( from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging.opentracing import trace +from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, @@ -506,3 +507,15 @@ class StateStorageController: await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_current_hosts_in_room(room_id) + + async def get_users_in_room_with_profiles( + self, room_id: str + ) -> Dict[str, ProfileInfo]: + """ + Get the current users in the room with their profiles. + If the room is currently partial-stated, this will block until the room has + full state. + """ + await self._partial_state_room_tracker.await_full_state(room_id) + + return await self.stores.main.get_users_in_room_with_profiles(room_id) -- cgit 1.5.1 From 0a4efbc1ddc3a58a6d75ad5d4d960b9ed367481e Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 16 Aug 2022 12:39:40 -0500 Subject: Instrument the federation/backfill part of `/messages` (#13489) Instrument the federation/backfill part of `/messages` so it's easier to follow what's going on in Jaeger when viewing a trace. Split out from https://github.com/matrix-org/synapse/pull/13440 Follow-up from https://github.com/matrix-org/synapse/pull/13368 Part of https://github.com/matrix-org/synapse/issues/13356 --- changelog.d/13489.misc | 1 + synapse/federation/federation_client.py | 27 ++++- synapse/handlers/federation.py | 10 +- synapse/handlers/federation_event.py | 112 ++++++++++++++++++--- synapse/logging/opentracing.py | 19 +++- synapse/storage/controllers/persist_events.py | 30 ++++-- synapse/storage/controllers/state.py | 5 +- synapse/storage/databases/main/event_federation.py | 6 ++ synapse/storage/databases/main/events.py | 2 + synapse/storage/databases/main/events_worker.py | 38 +++++-- .../storage/util/partial_state_events_tracker.py | 3 + 11 files changed, 220 insertions(+), 33 deletions(-) create mode 100644 changelog.d/13489.misc (limited to 'synapse/storage/controllers/state.py') diff --git a/changelog.d/13489.misc b/changelog.d/13489.misc new file mode 100644 index 0000000000..5e4853860e --- /dev/null +++ b/changelog.d/13489.misc @@ -0,0 +1 @@ +Instrument the federation/backfill part of `/messages` for understandable traces in Jaeger. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 54ffbd8170..987f6dad46 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -61,7 +61,7 @@ from synapse.federation.federation_base import ( ) from synapse.federation.transport.client import SendJoinResponse from synapse.http.types import QueryParams -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -235,6 +235,7 @@ class FederationClient(FederationBase): ) @trace + @tag_args async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> Optional[List[EventBase]]: @@ -337,6 +338,8 @@ class FederationClient(FederationBase): return None + @trace + @tag_args async def get_pdu( self, destinations: Iterable[str], @@ -448,6 +451,8 @@ class FederationClient(FederationBase): return event_copy + @trace + @tag_args async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> Tuple[List[str], List[str]]: @@ -467,6 +472,23 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) + set_tag( + SynapseTags.RESULT_PREFIX + "state_event_ids", + str(state_event_ids), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "state_event_ids.length", + str(len(state_event_ids)), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "auth_event_ids", + str(auth_event_ids), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "auth_event_ids.length", + str(len(auth_event_ids)), + ) + if not isinstance(state_event_ids, list) or not isinstance( auth_event_ids, list ): @@ -474,6 +496,8 @@ class FederationClient(FederationBase): return state_event_ids, auth_event_ids + @trace + @tag_args async def get_room_state( self, destination: str, @@ -533,6 +557,7 @@ class FederationClient(FederationBase): return valid_state_events, valid_auth_events + @trace async def _check_sigs_and_hash_and_fetch( self, origin: str, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6f5ab86ac4..d13011d138 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -59,7 +59,7 @@ from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context -from synapse.logging.opentracing import tag_args, trace +from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import NOT_SPAM from synapse.replication.http.federation import ( @@ -370,6 +370,14 @@ class FederationHandler: logger.debug( "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request ) + set_tag( + SynapseTags.RESULT_PREFIX + "extremities_to_request", + str(extremities_to_request), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "extremities_to_request.length", + str(len(extremities_to_request)), + ) # Now we need to decide which hosts to hit first. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 8968b705d4..dd0d610fe9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -59,7 +59,13 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError from synapse.logging.context import nested_logging_context -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import ( + SynapseTags, + set_tag, + start_active_span, + tag_args, + trace, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( @@ -410,6 +416,7 @@ class FederationEventHandler: prev_member_event, ) + @trace async def process_remote_join( self, origin: str, @@ -715,7 +722,7 @@ class FederationEventHandler: @trace async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase], backfilled: bool + self, origin: str, events: Collection[EventBase], backfilled: bool ) -> None: """Process a batch of events we have pulled from a remote server @@ -730,6 +737,15 @@ class FederationEventHandler: backfilled: True if this is part of a historical batch of events (inhibits notification to clients, and validation of device keys.) """ + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids", + str([event.event_id for event in events]), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(events)), + ) + set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled)) logger.debug( "processing pulled backfilled=%s events=%s", backfilled, @@ -753,6 +769,7 @@ class FederationEventHandler: await self._process_pulled_event(origin, ev, backfilled=backfilled) @trace + @tag_args async def _process_pulled_event( self, origin: str, event: EventBase, backfilled: bool ) -> None: @@ -854,6 +871,7 @@ class FederationEventHandler: else: raise + @trace async def _compute_event_context_with_maybe_missing_prevs( self, dest: str, event: EventBase ) -> EventContext: @@ -970,6 +988,8 @@ class FederationEventHandler: event, state_ids_before_event=state_map, partial_state=partial_state ) + @trace + @tag_args async def _get_state_ids_after_missing_prev_event( self, destination: str, @@ -1009,10 +1029,10 @@ class FederationEventHandler: logger.debug("Fetching %i events from cache/store", len(desired_events)) have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - have_events + missing_desired_event_ids = desired_events - have_events logger.debug( "We are missing %i events (got %i)", - len(missing_desired_events), + len(missing_desired_event_ids), len(have_events), ) @@ -1024,13 +1044,30 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - have_events - missing_auth_events.difference_update( - await self._store.have_seen_events(room_id, missing_auth_events) + missing_auth_event_ids = set(auth_event_ids) - have_events + missing_auth_event_ids.difference_update( + await self._store.have_seen_events(room_id, missing_auth_event_ids) ) - logger.debug("We are also missing %i auth events", len(missing_auth_events)) + logger.debug("We are also missing %i auth events", len(missing_auth_event_ids)) - missing_events = missing_desired_events | missing_auth_events + missing_event_ids = missing_desired_event_ids | missing_auth_event_ids + + set_tag( + SynapseTags.RESULT_PREFIX + "missing_auth_event_ids", + str(missing_auth_event_ids), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length", + str(len(missing_auth_event_ids)), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "missing_desired_event_ids", + str(missing_desired_event_ids), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length", + str(len(missing_desired_event_ids)), + ) # Making an individual request for each of 1000s of events has a lot of # overhead. On the other hand, we don't really want to fetch all of the events @@ -1041,13 +1078,13 @@ class FederationEventHandler: # # TODO: might it be better to have an API which lets us do an aggregate event # request - if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids): + if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids): logger.debug("Requesting complete state from remote") await self._get_state_and_persist(destination, room_id, event_id) else: - logger.debug("Fetching %i events from remote", len(missing_events)) + logger.debug("Fetching %i events from remote", len(missing_event_ids)) await self._get_events_and_persist( - destination=destination, room_id=room_id, event_ids=missing_events + destination=destination, room_id=room_id, event_ids=missing_event_ids ) # We now need to fill out the state map, which involves fetching the @@ -1104,6 +1141,14 @@ class FederationEventHandler: event_id, failed_to_fetch, ) + set_tag( + SynapseTags.RESULT_PREFIX + "failed_to_fetch", + str(failed_to_fetch), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "failed_to_fetch.length", + str(len(failed_to_fetch)), + ) if remote_event.is_state() and remote_event.rejected_reason is None: state_map[ @@ -1112,6 +1157,8 @@ class FederationEventHandler: return state_map + @trace + @tag_args async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str ) -> None: @@ -1133,6 +1180,7 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=(event_id,) ) + @trace async def _process_received_pdu( self, origin: str, @@ -1283,6 +1331,7 @@ class FederationEventHandler: except Exception: logger.exception("Failed to resync device for %s", sender) + @trace async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None: """Handles backfilling the insertion event when we receive a marker event that points to one. @@ -1414,6 +1463,8 @@ class FederationEventHandler: return event_from_response + @trace + @tag_args async def _get_events_and_persist( self, destination: str, room_id: str, event_ids: Collection[str] ) -> None: @@ -1459,6 +1510,7 @@ class FederationEventHandler: logger.info("Fetched %i events of %i requested", len(events), len(event_ids)) await self._auth_and_persist_outliers(room_id, events) + @trace async def _auth_and_persist_outliers( self, room_id: str, events: Iterable[EventBase] ) -> None: @@ -1477,6 +1529,16 @@ class FederationEventHandler: """ event_map = {event.event_id: event for event in events} + event_ids = event_map.keys() + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids", + str(event_ids), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + # filter out any events we have already seen. This might happen because # the events were eagerly pushed to us (eg, during a room join), or because # another thread has raced against us since we decided to request the event. @@ -1593,6 +1655,7 @@ class FederationEventHandler: backfilled=True, ) + @trace async def _check_event_auth( self, origin: Optional[str], event: EventBase, context: EventContext ) -> None: @@ -1631,6 +1694,14 @@ class FederationEventHandler: claimed_auth_events = await self._load_or_fetch_auth_events_for_event( origin, event ) + set_tag( + SynapseTags.RESULT_PREFIX + "claimed_auth_events", + str([ev.event_id for ev in claimed_auth_events]), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "claimed_auth_events.length", + str(len(claimed_auth_events)), + ) # ... and check that the event passes auth at those auth events. # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu: @@ -1728,6 +1799,7 @@ class FederationEventHandler: ) context.rejected = RejectedReason.AUTH_ERROR + @trace async def _maybe_kick_guest_users(self, event: EventBase) -> None: if event.type != EventTypes.GuestAccess: return @@ -1935,6 +2007,8 @@ class FederationEventHandler: # instead we raise an AuthError, which will make the caller ignore it. raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found") + @trace + @tag_args async def _get_remote_auth_chain_for_event( self, destination: str, room_id: str, event_id: str ) -> None: @@ -1963,6 +2037,7 @@ class FederationEventHandler: await self._auth_and_persist_outliers(room_id, remote_auth_events) + @trace async def _run_push_actions_and_persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False ) -> None: @@ -2071,8 +2146,17 @@ class FederationEventHandler: self._message_handler.maybe_schedule_expiry(event) if not backfilled: # Never notify for backfilled events - for event in events: - await self._notify_persisted_event(event, max_stream_token) + with start_active_span("notify_persisted_events"): + set_tag( + SynapseTags.RESULT_PREFIX + "event_ids", + str([ev.event_id for ev in events]), + ) + set_tag( + SynapseTags.RESULT_PREFIX + "event_ids.length", + str(len(events)), + ) + for event in events: + await self._notify_persisted_event(event, max_stream_token) return max_stream_token.stream diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index d1fa2cf8ae..482316a1ff 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -310,6 +310,19 @@ class SynapseTags: # The name of the external cache CACHE_NAME = "cache.name" + # Used to tag function arguments + # + # Tag a named arg. The name of the argument should be appended to this prefix. + FUNC_ARG_PREFIX = "ARG." + # Tag extra variadic number of positional arguments (`def foo(first, second, *extras)`) + FUNC_ARGS = "args" + # Tag keyword args + FUNC_KWARGS = "kwargs" + + # Some intermediate result that's interesting to the function. The label for + # the result should be appended to this prefix. + RESULT_PREFIX = "RESULT." + class SynapseBaggage: FORCE_TRACING = "synapse-force-tracing" @@ -967,9 +980,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: # first argument only if it's named `self` or `cls`. This isn't fool-proof # but handles the idiomatic cases. for i, arg in enumerate(args[1:], start=1): # type: ignore[index] - set_tag("ARG_" + argspec.args[i], str(arg)) - set_tag("args", str(args[len(argspec.args) :])) # type: ignore[index] - set_tag("kwargs", str(kwargs)) + set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg)) + set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :])) # type: ignore[index] + set_tag(SynapseTags.FUNC_KWARGS, str(kwargs)) yield return _custom_sync_async_decorator(func, _wrapping_logic) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index cf98b0ab48..dad3731b9b 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -45,8 +45,14 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.logging import opentracing from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.opentracing import ( + SynapseTags, + active_span, + set_tag, + start_active_span_follows_from, + trace, +) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases @@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): queue.append(end_item) # also add our active opentracing span to the item so that we get a link back - span = opentracing.active_span() + span = active_span() if span: end_item.parent_opentracing_span_contexts.append(span.context) @@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): res = await make_deferred_yieldable(end_item.deferred.observe()) # add another opentracing span which links to the persist trace. - with opentracing.start_active_span_follows_from( + with start_active_span_follows_from( f"{task.name}_complete", (end_item.opentracing_span_context,) ): pass @@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): queue = self._get_drainining_queue(room_id) for item in queue: try: - with opentracing.start_active_span_follows_from( + with start_active_span_follows_from( item.task.name, item.parent_opentracing_span_contexts, inherit_force_tracing=True, @@ -355,7 +361,7 @@ class EventsPersistenceStorageController: f"Found an unexpected task type in event persistence queue: {task}" ) - @opentracing.trace + @trace async def persist_events( self, events_and_contexts: Iterable[Tuple[EventBase, EventContext]], @@ -380,9 +386,21 @@ class EventsPersistenceStorageController: PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ + event_ids: List[str] = [] partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) + event_ids.append(event.event_id) + + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids", + str(event_ids), + ) + set_tag( + SynapseTags.FUNC_ARG_PREFIX + "event_ids.length", + str(len(event_ids)), + ) + set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled)) async def enqueue( item: Tuple[str, List[Tuple[EventBase, EventContext]]] @@ -418,7 +436,7 @@ class EventsPersistenceStorageController: self.main_store.get_room_max_token(), ) - @opentracing.trace + @trace async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0c78eb735e..1ad002f57b 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -29,7 +29,7 @@ from typing import ( from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( @@ -229,6 +229,7 @@ class StateStorageController: return {event: event_to_state[event] for event in event_ids} @trace + @tag_args async def get_state_ids_for_events( self, event_ids: Collection[str], @@ -333,6 +334,7 @@ class StateStorageController: ) @trace + @tag_args async def get_state_group_for_events( self, event_ids: Collection[str], @@ -474,6 +476,7 @@ class StateStorageController: prev_stream_id, max_stream_id ) + @trace async def get_current_state( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[EventBase]: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0bc8401f2b..c836078da6 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -712,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} + @trace + @tag_args async def get_oldest_event_ids_with_depth_in_room( self, room_id: str ) -> List[Tuple[str, int]]: @@ -770,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) + @trace async def get_insertion_event_backward_extremities_in_room( self, room_id: str ) -> List[Tuple[str, int]]: @@ -1342,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_results.reverse() return event_results + @trace + @tag_args async def get_successor_events(self, event_id: str) -> List[str]: """Fetch all events that have the given event as a prev event @@ -1378,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + @trace async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: await self.db_pool.simple_upsert( table="insertion_event_extremities", diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5560b38a48..a4010ee28d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext +from synapse.logging.opentracing import trace from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -145,6 +146,7 @@ class PersistEventsStore: self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen + @trace async def _persist_events_and_state_updates( self, events_and_contexts: List[Tuple[EventBase, EventContext]], diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b07d812ae2..8a7cdb024d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -54,6 +54,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) +from synapse.logging.opentracing import start_active_span, tag_args, trace from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore): return {e.event_id: e for e in events} + @trace + @tag_args async def get_events_as_list( self, event_ids: Collection[str], @@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore): """ fetched_event_ids: Set[str] = set() fetched_events: Dict[str, _EventRow] = {} - events_to_fetch = event_ids - while events_to_fetch: - row_map = await self._enqueue_events(events_to_fetch) + async def _fetch_event_ids_and_get_outstanding_redactions( + event_ids_to_fetch: Collection[str], + ) -> Collection[str]: + """ + Fetch all of the given event_ids and return any associated redaction event_ids + that we still need to fetch in the next iteration. + """ + row_map = await self._enqueue_events(event_ids_to_fetch) # we need to recursively fetch any redactions of those events redaction_ids: Set[str] = set() - for event_id in events_to_fetch: + for event_id in event_ids_to_fetch: row = row_map.get(event_id) fetched_event_ids.add(event_id) if row: fetched_events[event_id] = row redaction_ids.update(row.redactions) - events_to_fetch = redaction_ids.difference(fetched_event_ids) - if events_to_fetch: - logger.debug("Also fetching redaction events %s", events_to_fetch) + event_ids_to_fetch = redaction_ids.difference(fetched_event_ids) + return event_ids_to_fetch + + # Grab the initial list of events requested + event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions( + event_ids + ) + # Then go and recursively find all of the associated redactions + with start_active_span("recursively fetching redactions"): + while event_ids_to_fetch: + logger.debug("Also fetching redaction events %s", event_ids_to_fetch) + + event_ids_to_fetch = ( + await _fetch_event_ids_and_get_outstanding_redactions( + event_ids_to_fetch + ) + ) # build a map from event_id to EventBase event_map: Dict[str, EventBase] = {} @@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} + @trace + @tag_args async def have_seen_events( self, room_id: str, event_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index 466e5137f2..b4bf49dace 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -20,6 +20,7 @@ from twisted.internet import defer from twisted.internet.defer import Deferred from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.logging.opentracing import trace_with_opname from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError @@ -58,6 +59,7 @@ class PartialStateEventsTracker: for o in observers: o.callback(None) + @trace_with_opname("PartialStateEventsTracker.await_full_state") async def await_full_state(self, event_ids: Collection[str]) -> None: """Wait for all the given events to have full state. @@ -151,6 +153,7 @@ class PartialCurrentStateTracker: for o in observers: o.callback(None) + @trace_with_opname("PartialCurrentStateTracker.await_full_state") async def await_full_state(self, room_id: str) -> None: # We add the deferred immediately so that the DB call to check for # partial state doesn't race when we unpartial the room. -- cgit 1.5.1 From 84169a82dcf7dfb6eb7d307ea7f5e33cb57f6e3f Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:53:02 +0100 Subject: Avoid blocking lazy-loading `/sync`s during partial joins (#13477) Use a state filter or accept partial state in a few places where we request state, to avoid blocking. To make lazy-loading `/sync`s work, we need to provide the memberships of event senders, which are not guaranteed to be in the room state. Instead we dig through auth events for memberships to present to clients. The auth events of an event are guaranteed to contain a passable membership event, otherwise the event would have been rejected. Note that this only covers the common code paths encountered during testing. There has been no exhaustive checking of all sync code paths. Fixes #13146. Signed-off-by: Sean Quah --- changelog.d/13477.misc | 1 + synapse/handlers/sync.py | 253 ++++++++++++++++++++++++++++++----- synapse/storage/controllers/state.py | 24 +++- 3 files changed, 244 insertions(+), 34 deletions(-) create mode 100644 changelog.d/13477.misc (limited to 'synapse/storage/controllers/state.py') diff --git a/changelog.d/13477.misc b/changelog.d/13477.misc new file mode 100644 index 0000000000..5d21ae9d7a --- /dev/null +++ b/changelog.d/13477.misc @@ -0,0 +1 @@ +Faster room joins: Avoid blocking lazy-loading `/sync`s during partial joins due to remote memberships. Pull remote memberships from auth events instead of the room state. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3ca01391c9..b4d3f3958c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -16,9 +16,11 @@ import logging from typing import ( TYPE_CHECKING, Any, + Collection, Dict, FrozenSet, List, + Mapping, Optional, Sequence, Set, @@ -517,10 +519,17 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `recents`, so partial state is only a problem when a membership + # event turns up in `recents` but has not made it into the current + # state. current_state_ids_map = ( - await self._state_storage_controller.get_current_state_ids( - room_id - ) + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -589,7 +598,13 @@ class SyncHandler: if any(e.is_state() for e in loaded_recents): # FIXME(faster_joins): We use the partial state here as # we don't want to block `/sync` on finishing a lazy join. - # Is this the correct way of doing it? + # Which should be fine once + # https://github.com/matrix-org/synapse/issues/12989 is resolved, + # since we shouldn't reach here anymore? + # Note that we use the current state as a whitelist for filtering + # `loaded_recents`, so partial state is only a problem when a + # membership event turns up in `loaded_recents` but has not made it + # into the current state. current_state_ids_map = ( await self.store.get_partial_current_state_ids(room_id) ) @@ -637,7 +652,10 @@ class SyncHandler: ) async def get_state_after_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the room state after the given event @@ -645,9 +663,14 @@ class SyncHandler: Args: event_id: event of interest state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. """ state_ids = await self._state_storage_controller.get_state_ids_for_event( - event_id, state_filter=state_filter or StateFilter.all() + event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) # using get_metadata_for_events here (instead of get_event) sidesteps an issue @@ -670,6 +693,7 @@ class SyncHandler: room_id: str, stream_position: StreamToken, state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Get the room state at a particular stream position @@ -677,6 +701,9 @@ class SyncHandler: room_id: room for which to get state stream_position: point at which to get state state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the last event in the room before `stream_position` and + `state_filter` is not satisfied by partial state. Defaults to `True`. """ # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time @@ -688,7 +715,9 @@ class SyncHandler: if last_event_id: state = await self.get_state_after_event( - last_event_id, state_filter=state_filter or StateFilter.all() + last_event_id, + state_filter=state_filter or StateFilter.all(), + await_full_state=await_full_state, ) else: @@ -891,7 +920,15 @@ class SyncHandler: with Measure(self.clock, "compute_state_delta"): # The memberships needed for events in the timeline. # Only calculated when `lazy_load_members` is on. - members_to_fetch = None + members_to_fetch: Optional[Set[str]] = None + + # A dictionary mapping user IDs to the first event in the timeline sent by + # them. Only calculated when `lazy_load_members` is on. + first_event_by_sender_map: Optional[Dict[str, EventBase]] = None + + # The contribution to the room state from state events in the timeline. + # Only contains the last event for any given state key. + timeline_state: StateMap[str] lazy_load_members = sync_config.filter_collection.lazy_load_members() include_redundant_members = ( @@ -902,10 +939,23 @@ class SyncHandler: # We only request state for the members needed to display the # timeline: - members_to_fetch = { - event.sender # FIXME: we also care about invite targets etc. - for event in batch.events - } + timeline_state = {} + + members_to_fetch = set() + first_event_by_sender_map = {} + for event in batch.events: + # Build the map from user IDs to the first timeline event they sent. + if event.sender not in first_event_by_sender_map: + first_event_by_sender_map[event.sender] = event + + # We need the event's sender, unless their membership was in a + # previous timeline event. + if (EventTypes.Member, event.sender) not in timeline_state: + members_to_fetch.add(event.sender) + # FIXME: we also care about invite targets etc. + + if event.is_state(): + timeline_state[(event.type, event.state_key)] = event.event_id if full_state: # always make sure we LL ourselves so we know we're in the room @@ -915,16 +965,21 @@ class SyncHandler: members_to_fetch.add(sync_config.user.to_string()) state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch) + + # We are happy to use partial state to compute the `/sync` response. + # Since partial state may not include the lazy-loaded memberships we + # require, we fix up the state response afterwards with memberships from + # auth events. + await_full_state = False else: - state_filter = StateFilter.all() + timeline_state = { + (event.type, event.state_key): event.event_id + for event in batch.events + if event.is_state() + } - # The contribution to the room state from state events in the timeline. - # Only contains the last event for any given state key. - timeline_state = { - (event.type, event.state_key): event.event_id - for event in batch.events - if event.is_state() - } + state_filter = StateFilter.all() + await_full_state = True # Now calculate the state to return in the sync response for the room. # This is more or less the change in state between the end of the previous @@ -936,19 +991,26 @@ class SyncHandler: if batch: state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: state_at_timeline_end = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_at_timeline_start = state_at_timeline_end @@ -964,14 +1026,19 @@ class SyncHandler: if batch: state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + batch.events[0].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_start = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) # for now, we disable LL for gappy syncs - see @@ -993,20 +1060,28 @@ class SyncHandler: # is indeed the case. assert since_token is not None state_at_previous_sync = await self.get_state_at( - room_id, stream_position=since_token, state_filter=state_filter + room_id, + stream_position=since_token, + state_filter=state_filter, + await_full_state=await_full_state, ) if batch: state_at_timeline_end = ( await self._state_storage_controller.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + batch.events[-1].event_id, + state_filter=state_filter, + await_full_state=await_full_state, ) ) else: # We can get here if the user has ignored the senders of all # the recent events. state_at_timeline_end = await self.get_state_at( - room_id, stream_position=now_token, state_filter=state_filter + room_id, + stream_position=now_token, + state_filter=state_filter, + await_full_state=await_full_state, ) state_ids = _calculate_state( @@ -1036,8 +1111,23 @@ class SyncHandler: (EventTypes.Member, member) for member in members_to_fetch ), + await_full_state=False, ) + # If we only have partial state for the room, `state_ids` may be missing the + # memberships we wanted. We attempt to find some by digging through the auth + # events of timeline events. + if lazy_load_members and await self.store.is_partial_state_room(room_id): + assert members_to_fetch is not None + assert first_event_by_sender_map is not None + + additional_state_ids = ( + await self._find_missing_partial_state_memberships( + room_id, members_to_fetch, first_event_by_sender_map, state_ids + ) + ) + state_ids = {**state_ids, **additional_state_ids} + # At this point, if `lazy_load_members` is enabled, `state_ids` includes # the memberships of all event senders in the timeline. This is because we # may not have sent the memberships in a previous sync. @@ -1086,6 +1176,99 @@ class SyncHandler: if e.type != EventTypes.Aliases # until MSC2261 or alternative solution } + async def _find_missing_partial_state_memberships( + self, + room_id: str, + members_to_fetch: Collection[str], + events_with_membership_auth: Mapping[str, EventBase], + found_state_ids: StateMap[str], + ) -> StateMap[str]: + """Finds missing memberships from a set of auth events and returns them as a + state map. + + Args: + room_id: The partial state room to find the remaining memberships for. + members_to_fetch: The memberships to find. + events_with_membership_auth: A mapping from user IDs to events whose auth + events are known to contain their membership. + found_state_ids: A dict from (type, state_key) -> state_event_id, containing + memberships that have been previously found. Entries in + `members_to_fetch` that have a membership in `found_state_ids` are + ignored. + + Returns: + A dict from ("m.room.member", state_key) -> state_event_id, containing the + memberships missing from `found_state_ids`. + + Raises: + KeyError: if `events_with_membership_auth` does not have an entry for a + missing membership. Memberships in `found_state_ids` do not need an + entry in `events_with_membership_auth`. + """ + additional_state_ids: MutableStateMap[str] = {} + + # Tracks the missing members for logging purposes. + missing_members = set() + + # Identify memberships missing from `found_state_ids` and pick out the auth + # events in which to look for them. + auth_event_ids: Set[str] = set() + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + missing_members.add(member) + event_with_membership_auth = events_with_membership_auth[member] + auth_event_ids.update(event_with_membership_auth.auth_event_ids()) + + auth_events = await self.store.get_events(auth_event_ids) + + # Run through the missing memberships once more, picking out the memberships + # from the pile of auth events we have just fetched. + for member in members_to_fetch: + if (EventTypes.Member, member) in found_state_ids: + continue + + event_with_membership_auth = events_with_membership_auth[member] + + # Dig through the auth events to find the desired membership. + for auth_event_id in event_with_membership_auth.auth_event_ids(): + # We only store events once we have all their auth events, + # so the auth event must be in the pile we have just + # fetched. + auth_event = auth_events[auth_event_id] + + if ( + auth_event.type == EventTypes.Member + and auth_event.state_key == member + ): + missing_members.remove(member) + additional_state_ids[ + (EventTypes.Member, member) + ] = auth_event.event_id + break + + if missing_members: + # There really shouldn't be any missing memberships now. Either: + # * we couldn't find an auth event, which shouldn't happen because we do + # not persist events with persisting their auth events first, or + # * the set of auth events did not contain a membership we wanted, which + # means our caller didn't compute the events in `members_to_fetch` + # correctly, or we somehow accepted an event whose auth events were + # dodgy. + logger.error( + "Failed to find memberships for %s in partial state room " + "%s in the auth events of %s.", + missing_members, + room_id, + [ + events_with_membership_auth[member].event_id + for member in missing_members + ], + ) + + return additional_state_ids + async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig ) -> NotifCounts: @@ -1730,7 +1913,11 @@ class SyncHandler: continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]), + ) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: @@ -1756,7 +1943,13 @@ class SyncHandler: newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = await self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at( + room_id, + since_token, + state_filter=StateFilter.from_types( + [(EventTypes.Member, user_id)] + ), + ) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 1ad002f57b..f9ffd0e29e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -234,6 +234,7 @@ class StateStorageController: self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -242,6 +243,9 @@ class StateStorageController: Args: event_ids: events whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at these events and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: A dict from event_id -> (type, state_key) -> event_id @@ -250,8 +254,12 @@ class StateStorageController: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + if ( + await_full_state + and state_filter + and not state_filter.must_await_full_state(self._is_mine_id) + ): + # Full state is not required if the state filter is restrictive enough. await_full_state = False event_to_groups = await self.get_state_group_for_events( @@ -294,7 +302,10 @@ class StateStorageController: @trace async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None + self, + event_id: str, + state_filter: Optional[StateFilter] = None, + await_full_state: bool = True, ) -> StateMap[str]: """ Get the state dict corresponding to a particular event @@ -302,6 +313,9 @@ class StateStorageController: Args: event_id: event whose state should be returned state_filter: The state filter used to fetch state from the database. + await_full_state: if `True`, will block if we do not yet have complete state + at the event and `state_filter` is not satisfied by partial state. + Defaults to `True`. Returns: A dict from (type, state_key) -> state_event_id @@ -311,7 +325,9 @@ class StateStorageController: outlier or is unknown) """ state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() + [event_id], + state_filter or StateFilter.all(), + await_full_state=await_full_state, ) return state_map[event_id] -- cgit 1.5.1