summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-08-16 12:39:40 -0500
committerGitHub <noreply@github.com>2022-08-16 12:39:40 -0500
commit0a4efbc1ddc3a58a6d75ad5d4d960b9ed367481e (patch)
tree323829e05d63712953bf4c901d07519612cc3e51
parentMerge branch 'master' into develop (diff)
downloadsynapse-0a4efbc1ddc3a58a6d75ad5d4d960b9ed367481e.tar.xz
Instrument the federation/backfill part of `/messages` (#13489)
Instrument the federation/backfill part of `/messages` so it's easier to follow what's going on in Jaeger when viewing a trace.

Split out from https://github.com/matrix-org/synapse/pull/13440

Follow-up from https://github.com/matrix-org/synapse/pull/13368

Part of https://github.com/matrix-org/synapse/issues/13356
-rw-r--r--changelog.d/13489.misc1
-rw-r--r--synapse/federation/federation_client.py27
-rw-r--r--synapse/handlers/federation.py10
-rw-r--r--synapse/handlers/federation_event.py112
-rw-r--r--synapse/logging/opentracing.py19
-rw-r--r--synapse/storage/controllers/persist_events.py30
-rw-r--r--synapse/storage/controllers/state.py5
-rw-r--r--synapse/storage/databases/main/event_federation.py6
-rw-r--r--synapse/storage/databases/main/events.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py38
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py3
11 files changed, 220 insertions, 33 deletions
diff --git a/changelog.d/13489.misc b/changelog.d/13489.misc
new file mode 100644
index 0000000000..5e4853860e
--- /dev/null
+++ b/changelog.d/13489.misc
@@ -0,0 +1 @@
+Instrument the federation/backfill part of `/messages` for understandable traces in Jaeger.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 54ffbd8170..987f6dad46 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -61,7 +61,7 @@ from synapse.federation.federation_base import (
 )
 from synapse.federation.transport.client import SendJoinResponse
 from synapse.http.types import QueryParams
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -235,6 +235,7 @@ class FederationClient(FederationBase):
         )
 
     @trace
+    @tag_args
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> Optional[List[EventBase]]:
@@ -337,6 +338,8 @@ class FederationClient(FederationBase):
 
         return None
 
+    @trace
+    @tag_args
     async def get_pdu(
         self,
         destinations: Iterable[str],
@@ -448,6 +451,8 @@ class FederationClient(FederationBase):
 
         return event_copy
 
+    @trace
+    @tag_args
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> Tuple[List[str], List[str]]:
@@ -467,6 +472,23 @@ class FederationClient(FederationBase):
         state_event_ids = result["pdu_ids"]
         auth_event_ids = result.get("auth_chain_ids", [])
 
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "state_event_ids",
+            str(state_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "state_event_ids.length",
+            str(len(state_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "auth_event_ids",
+            str(auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "auth_event_ids.length",
+            str(len(auth_event_ids)),
+        )
+
         if not isinstance(state_event_ids, list) or not isinstance(
             auth_event_ids, list
         ):
@@ -474,6 +496,8 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    @trace
+    @tag_args
     async def get_room_state(
         self,
         destination: str,
@@ -533,6 +557,7 @@ class FederationClient(FederationBase):
 
         return valid_state_events, valid_auth_events
 
+    @trace
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6f5ab86ac4..d13011d138 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -59,7 +59,7 @@ from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import nested_logging_context
-from synapse.logging.opentracing import tag_args, trace
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import NOT_SPAM
 from synapse.replication.http.federation import (
@@ -370,6 +370,14 @@ class FederationHandler:
         logger.debug(
             "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request",
+            str(extremities_to_request),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request.length",
+            str(len(extremities_to_request)),
+        )
 
         # Now we need to decide which hosts to hit first.
 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 8968b705d4..dd0d610fe9 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -59,7 +59,13 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.logging.context import nested_logging_context
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import (
+    SynapseTags,
+    set_tag,
+    start_active_span,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
@@ -410,6 +416,7 @@ class FederationEventHandler:
             prev_member_event,
         )
 
+    @trace
     async def process_remote_join(
         self,
         origin: str,
@@ -715,7 +722,7 @@ class FederationEventHandler:
 
     @trace
     async def _process_pulled_events(
-        self, origin: str, events: Iterable[EventBase], backfilled: bool
+        self, origin: str, events: Collection[EventBase], backfilled: bool
     ) -> None:
         """Process a batch of events we have pulled from a remote server
 
@@ -730,6 +737,15 @@ class FederationEventHandler:
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str([event.event_id for event in events]),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(events)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
         logger.debug(
             "processing pulled backfilled=%s events=%s",
             backfilled,
@@ -753,6 +769,7 @@ class FederationEventHandler:
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
 
     @trace
+    @tag_args
     async def _process_pulled_event(
         self, origin: str, event: EventBase, backfilled: bool
     ) -> None:
@@ -854,6 +871,7 @@ class FederationEventHandler:
             else:
                 raise
 
+    @trace
     async def _compute_event_context_with_maybe_missing_prevs(
         self, dest: str, event: EventBase
     ) -> EventContext:
@@ -970,6 +988,8 @@ class FederationEventHandler:
             event, state_ids_before_event=state_map, partial_state=partial_state
         )
 
+    @trace
+    @tag_args
     async def _get_state_ids_after_missing_prev_event(
         self,
         destination: str,
@@ -1009,10 +1029,10 @@ class FederationEventHandler:
         logger.debug("Fetching %i events from cache/store", len(desired_events))
         have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - have_events
+        missing_desired_event_ids = desired_events - have_events
         logger.debug(
             "We are missing %i events (got %i)",
-            len(missing_desired_events),
+            len(missing_desired_event_ids),
             len(have_events),
         )
 
@@ -1024,13 +1044,30 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - have_events
-        missing_auth_events.difference_update(
-            await self._store.have_seen_events(room_id, missing_auth_events)
+        missing_auth_event_ids = set(auth_event_ids) - have_events
+        missing_auth_event_ids.difference_update(
+            await self._store.have_seen_events(room_id, missing_auth_event_ids)
         )
-        logger.debug("We are also missing %i auth events", len(missing_auth_events))
+        logger.debug("We are also missing %i auth events", len(missing_auth_event_ids))
 
-        missing_events = missing_desired_events | missing_auth_events
+        missing_event_ids = missing_desired_event_ids | missing_auth_event_ids
+
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids",
+            str(missing_auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length",
+            str(len(missing_auth_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids",
+            str(missing_desired_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length",
+            str(len(missing_desired_event_ids)),
+        )
 
         # Making an individual request for each of 1000s of events has a lot of
         # overhead. On the other hand, we don't really want to fetch all of the events
@@ -1041,13 +1078,13 @@ class FederationEventHandler:
         #
         # TODO: might it be better to have an API which lets us do an aggregate event
         #   request
-        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+        if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids):
             logger.debug("Requesting complete state from remote")
             await self._get_state_and_persist(destination, room_id, event_id)
         else:
-            logger.debug("Fetching %i events from remote", len(missing_events))
+            logger.debug("Fetching %i events from remote", len(missing_event_ids))
             await self._get_events_and_persist(
-                destination=destination, room_id=room_id, event_ids=missing_events
+                destination=destination, room_id=room_id, event_ids=missing_event_ids
             )
 
         # We now need to fill out the state map, which involves fetching the
@@ -1104,6 +1141,14 @@ class FederationEventHandler:
                 event_id,
                 failed_to_fetch,
             )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch",
+                str(failed_to_fetch),
+            )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch.length",
+                str(len(failed_to_fetch)),
+            )
 
         if remote_event.is_state() and remote_event.rejected_reason is None:
             state_map[
@@ -1112,6 +1157,8 @@ class FederationEventHandler:
 
         return state_map
 
+    @trace
+    @tag_args
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1133,6 +1180,7 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=(event_id,)
             )
 
+    @trace
     async def _process_received_pdu(
         self,
         origin: str,
@@ -1283,6 +1331,7 @@ class FederationEventHandler:
         except Exception:
             logger.exception("Failed to resync device for %s", sender)
 
+    @trace
     async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
         """Handles backfilling the insertion event when we receive a marker
         event that points to one.
@@ -1414,6 +1463,8 @@ class FederationEventHandler:
 
         return event_from_response
 
+    @trace
+    @tag_args
     async def _get_events_and_persist(
         self, destination: str, room_id: str, event_ids: Collection[str]
     ) -> None:
@@ -1459,6 +1510,7 @@ class FederationEventHandler:
         logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
         await self._auth_and_persist_outliers(room_id, events)
 
+    @trace
     async def _auth_and_persist_outliers(
         self, room_id: str, events: Iterable[EventBase]
     ) -> None:
@@ -1477,6 +1529,16 @@ class FederationEventHandler:
         """
         event_map = {event.event_id: event for event in events}
 
+        event_ids = event_map.keys()
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         # filter out any events we have already seen. This might happen because
         # the events were eagerly pushed to us (eg, during a room join), or because
         # another thread has raced against us since we decided to request the event.
@@ -1593,6 +1655,7 @@ class FederationEventHandler:
             backfilled=True,
         )
 
+    @trace
     async def _check_event_auth(
         self, origin: Optional[str], event: EventBase, context: EventContext
     ) -> None:
@@ -1631,6 +1694,14 @@ class FederationEventHandler:
         claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
             origin, event
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events",
+            str([ev.event_id for ev in claimed_auth_events]),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events.length",
+            str(len(claimed_auth_events)),
+        )
 
         # ... and check that the event passes auth at those auth events.
         # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
@@ -1728,6 +1799,7 @@ class FederationEventHandler:
             )
             context.rejected = RejectedReason.AUTH_ERROR
 
+    @trace
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1935,6 +2007,8 @@ class FederationEventHandler:
         # instead we raise an AuthError, which will make the caller ignore it.
         raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
 
+    @trace
+    @tag_args
     async def _get_remote_auth_chain_for_event(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1963,6 +2037,7 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
+    @trace
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None:
@@ -2071,8 +2146,17 @@ class FederationEventHandler:
                     self._message_handler.maybe_schedule_expiry(event)
 
             if not backfilled:  # Never notify for backfilled events
-                for event in events:
-                    await self._notify_persisted_event(event, max_stream_token)
+                with start_active_span("notify_persisted_events"):
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids",
+                        str([ev.event_id for ev in events]),
+                    )
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids.length",
+                        str(len(events)),
+                    )
+                    for event in events:
+                        await self._notify_persisted_event(event, max_stream_token)
 
             return max_stream_token.stream
 
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index d1fa2cf8ae..482316a1ff 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -310,6 +310,19 @@ class SynapseTags:
     # The name of the external cache
     CACHE_NAME = "cache.name"
 
+    # Used to tag function arguments
+    #
+    # Tag a named arg. The name of the argument should be appended to this prefix.
+    FUNC_ARG_PREFIX = "ARG."
+    # Tag extra variadic number of positional arguments (`def foo(first, second, *extras)`)
+    FUNC_ARGS = "args"
+    # Tag keyword args
+    FUNC_KWARGS = "kwargs"
+
+    # Some intermediate result that's interesting to the function. The label for
+    # the result should be appended to this prefix.
+    RESULT_PREFIX = "RESULT."
+
 
 class SynapseBaggage:
     FORCE_TRACING = "synapse-force-tracing"
@@ -967,9 +980,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
         #   first argument only if it's named `self` or `cls`. This isn't fool-proof
         #   but handles the idiomatic cases.
         for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
-            set_tag("ARG_" + argspec.args[i], str(arg))
-        set_tag("args", str(args[len(argspec.args) :]))  # type: ignore[index]
-        set_tag("kwargs", str(kwargs))
+            set_tag(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i], str(arg))
+        set_tag(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :]))  # type: ignore[index]
+        set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
         yield
 
     return _custom_sync_async_decorator(func, _wrapping_logic)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab48..dad3731b9b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+    SynapseTags,
+    active_span,
+    set_tag,
+    start_active_span_follows_from,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
@@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             queue.append(end_item)
 
         # also add our active opentracing span to the item so that we get a link back
-        span = opentracing.active_span()
+        span = active_span()
         if span:
             end_item.parent_opentracing_span_contexts.append(span.context)
 
@@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         res = await make_deferred_yieldable(end_item.deferred.observe())
 
         # add another opentracing span which links to the persist trace.
-        with opentracing.start_active_span_follows_from(
+        with start_active_span_follows_from(
             f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
@@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        with opentracing.start_active_span_follows_from(
+                        with start_active_span_follows_from(
                             item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
@@ -355,7 +361,7 @@ class EventsPersistenceStorageController:
                 f"Found an unexpected task type in event persistence queue: {task}"
             )
 
-    @opentracing.trace
+    @trace
     async def persist_events(
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +386,21 @@ class EventsPersistenceStorageController:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
+        event_ids: List[str] = []
         partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
+            event_ids.append(event.event_id)
+
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
 
         async def enqueue(
             item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -418,7 +436,7 @@ class EventsPersistenceStorageController:
             self.main_store.get_room_max_token(),
         )
 
-    @opentracing.trace
+    @trace
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0c78eb735e..1ad002f57b 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,7 +29,7 @@ from typing import (
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import tag_args, trace
 from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.storage.util.partial_state_events_tracker import (
@@ -229,6 +229,7 @@ class StateStorageController:
         return {event: event_to_state[event] for event in event_ids}
 
     @trace
+    @tag_args
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
@@ -333,6 +334,7 @@ class StateStorageController:
         )
 
     @trace
+    @tag_args
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -474,6 +476,7 @@ class StateStorageController:
             prev_stream_id, max_stream_id
         )
 
+    @trace
     async def get_current_state(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0bc8401f2b..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -712,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -770,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
+    @trace
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -1342,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
+    @trace
+    @tag_args
     async def get_successor_events(self, event_id: str) -> List[str]:
         """Fetch all events that have the given event as a prev event
 
@@ -1378,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             _delete_old_forward_extrem_cache_txn,
         )
 
+    @trace
     async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
         await self.db_pool.simple_upsert(
             table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5560b38a48..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+    @trace
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b07d812ae2..8a7cdb024d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
 
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
+        async def _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids_to_fetch: Collection[str],
+        ) -> Collection[str]:
+            """
+            Fetch all of the given event_ids and return any associated redaction event_ids
+            that we still need to fetch in the next iteration.
+            """
+            row_map = await self._enqueue_events(event_ids_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
+            for event_id in event_ids_to_fetch:
                 row = row_map.get(event_id)
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_event_ids)
-            if events_to_fetch:
-                logger.debug("Also fetching redaction events %s", events_to_fetch)
+            event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+            return event_ids_to_fetch
+
+        # Grab the initial list of events requested
+        event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids
+        )
+        # Then go and recursively find all of the associated redactions
+        with start_active_span("recursively fetching redactions"):
+            while event_ids_to_fetch:
+                logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+                event_ids_to_fetch = (
+                    await _fetch_event_ids_and_get_outstanding_redactions(
+                        event_ids_to_fetch
+                    )
+                )
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 466e5137f2..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialStateEventsTracker.await_full_state")
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialCurrentStateTracker.await_full_state")
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.