summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/persist_events.py30
-rw-r--r--synapse/storage/controllers/state.py5
-rw-r--r--synapse/storage/databases/main/event_federation.py6
-rw-r--r--synapse/storage/databases/main/events.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py38
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py3
6 files changed, 70 insertions, 14 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab48..dad3731b9b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+    SynapseTags,
+    active_span,
+    set_tag,
+    start_active_span_follows_from,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
@@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             queue.append(end_item)
 
         # also add our active opentracing span to the item so that we get a link back
-        span = opentracing.active_span()
+        span = active_span()
         if span:
             end_item.parent_opentracing_span_contexts.append(span.context)
 
@@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         res = await make_deferred_yieldable(end_item.deferred.observe())
 
         # add another opentracing span which links to the persist trace.
-        with opentracing.start_active_span_follows_from(
+        with start_active_span_follows_from(
             f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
@@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        with opentracing.start_active_span_follows_from(
+                        with start_active_span_follows_from(
                             item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
@@ -355,7 +361,7 @@ class EventsPersistenceStorageController:
                 f"Found an unexpected task type in event persistence queue: {task}"
             )
 
-    @opentracing.trace
+    @trace
     async def persist_events(
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +386,21 @@ class EventsPersistenceStorageController:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
+        event_ids: List[str] = []
         partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
+            event_ids.append(event.event_id)
+
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
 
         async def enqueue(
             item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -418,7 +436,7 @@ class EventsPersistenceStorageController:
             self.main_store.get_room_max_token(),
         )
 
-    @opentracing.trace
+    @trace
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0c78eb735e..1ad002f57b 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,7 +29,7 @@ from typing import (
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import tag_args, trace
 from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.storage.util.partial_state_events_tracker import (
@@ -229,6 +229,7 @@ class StateStorageController:
         return {event: event_to_state[event] for event in event_ids}
 
     @trace
+    @tag_args
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
@@ -333,6 +334,7 @@ class StateStorageController:
         )
 
     @trace
+    @tag_args
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -474,6 +476,7 @@ class StateStorageController:
             prev_stream_id, max_stream_id
         )
 
+    @trace
     async def get_current_state(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0bc8401f2b..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -712,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -770,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
+    @trace
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -1342,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
+    @trace
+    @tag_args
     async def get_successor_events(self, event_id: str) -> List[str]:
         """Fetch all events that have the given event as a prev event
 
@@ -1378,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             _delete_old_forward_extrem_cache_txn,
         )
 
+    @trace
     async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
         await self.db_pool.simple_upsert(
             table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5560b38a48..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+    @trace
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b07d812ae2..8a7cdb024d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
 
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
+        async def _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids_to_fetch: Collection[str],
+        ) -> Collection[str]:
+            """
+            Fetch all of the given event_ids and return any associated redaction event_ids
+            that we still need to fetch in the next iteration.
+            """
+            row_map = await self._enqueue_events(event_ids_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
+            for event_id in event_ids_to_fetch:
                 row = row_map.get(event_id)
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_event_ids)
-            if events_to_fetch:
-                logger.debug("Also fetching redaction events %s", events_to_fetch)
+            event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+            return event_ids_to_fetch
+
+        # Grab the initial list of events requested
+        event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids
+        )
+        # Then go and recursively find all of the associated redactions
+        with start_active_span("recursively fetching redactions"):
+            while event_ids_to_fetch:
+                logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+                event_ids_to_fetch = (
+                    await _fetch_event_ids_and_get_outstanding_redactions(
+                        event_ids_to_fetch
+                    )
+                )
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 466e5137f2..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialStateEventsTracker.await_full_state")
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialCurrentStateTracker.await_full_state")
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.