diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 79680ee856..c6fa7f82fd 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -29,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
-from synapse.logging.context import (
- LoggingContext,
- PreserveLoggingContext,
- make_deferred_yieldable,
- run_in_background,
-)
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
@@ -342,13 +337,12 @@ class EventsWorkerStore(SQLBaseStore):
log_ctx = LoggingContext.current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
- # Note that _enqueue_events is also responsible for turning db rows
+ # Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- # _enqueue_events is a bit of a rubbish name but naming is hard.
- missing_events = yield self._enqueue_events(
+ missing_events = yield self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -421,28 +415,28 @@ class EventsWorkerStore(SQLBaseStore):
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
+
+ The deferreds are callbacked with a dictionary mapping from event id
+ to event row. Note that it may well contain additional events that
+ were not part of this request.
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- event_id_lists = list(zip(*event_list))[0]
- event_ids = [item for sublist in event_id_lists for item in sublist]
+ events_to_fetch = set(
+ event_id for events, _ in event_list for event_id in events
+ )
row_dict = self._new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
+ conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
# We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([res[i] for i in ids if i in res])
- except Exception:
- logger.exception("Failed to callback")
+ def fire():
+ for _, d in event_list:
+ d.callback(row_dict)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
+ self.hs.get_reactor().callFromThread(fire)
except Exception as e:
logger.exception("do_fetch")
@@ -457,13 +451,98 @@ class EventsWorkerStore(SQLBaseStore):
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
- def _enqueue_events(self, events, allow_rejected=False):
+ def _get_events_from_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the database.
+
+ Returned events will be added to the cache for future lookups.
+
+ Args:
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+ allow_rejected (bool): Whether to include rejected events
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result. May return extra events which
+ weren't asked for.
+ """
+ fetched_events = {}
+ events_to_fetch = event_ids
+
+ while events_to_fetch:
+ row_map = yield self._enqueue_events(events_to_fetch)
+
+ # we need to recursively fetch any redactions of those events
+ redaction_ids = set()
+ for event_id in events_to_fetch:
+ row = row_map.get(event_id)
+ fetched_events[event_id] = row
+ if row:
+ redaction_ids.update(row["redactions"])
+
+ events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ if events_to_fetch:
+ logger.debug("Also fetching redaction events %s", events_to_fetch)
+
+ # build a map from event_id to EventBase
+ event_map = {}
+ for event_id, row in fetched_events.items():
+ if not row:
+ continue
+ assert row["event_id"] == event_id
+
+ rejected_reason = row["rejected_reason"]
+
+ if not allow_rejected and rejected_reason:
+ continue
+
+ d = json.loads(row["json"])
+ internal_metadata = json.loads(row["internal_metadata"])
+
+ format_version = row["format_version"]
+ if format_version is None:
+ # This means that we stored the event before we had the concept
+ # of a event format version, so it must be a V1 event.
+ format_version = EventFormatVersions.V1
+
+ original_ev = event_type_from_format_version(format_version)(
+ event_dict=d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ event_map[event_id] = original_ev
+
+ # finally, we can decide whether each one nededs redacting, and build
+ # the cache entries.
+ result_map = {}
+ for event_id, original_ev in event_map.items():
+ redactions = fetched_events[event_id]["redactions"]
+ redacted_event = self._maybe_redact_event_row(
+ original_ev, redactions, event_map
+ )
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev, redacted_event=redacted_event
+ )
+
+ self._get_event_cache.prefill((event_id,), cache_entry)
+ result_map[event_id] = cache_entry
+
+ return result_map
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
+
+ Args:
+ events (Iterable[str]): events to be fetched.
+
+ Returns:
+ Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ May contain events that weren't requested.
"""
- if not events:
- return {}
events_d = defer.Deferred()
with self._event_fetch_lock:
@@ -482,32 +561,12 @@ class EventsWorkerStore(SQLBaseStore):
"fetch_events", self.runWithConnection, self._do_fetch
)
- logger.debug("Loading %d events", len(events))
+ logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if r["rejected_reason"] is None]
-
- res = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self._get_event_from_row,
- row["internal_metadata"],
- row["json"],
- row["redactions"],
- rejected_reason=row["rejected_reason"],
- format_version=row["format_version"],
- )
- for row in rows
- ],
- consumeErrors=True,
- )
- )
+ row_map = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
- return {e.event.event_id: e for e in res if e}
+ return row_map
def _fetch_event_rows(self, txn, event_ids):
"""Fetch event rows from the database
@@ -580,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict
- @defer.inlineCallbacks
- def _get_event_from_row(
- self, internal_metadata, js, redactions, format_version, rejected_reason=None
- ):
- """Parse an event row which has been read from the database
-
- Args:
- internal_metadata (str): json-encoded internal_metadata column
- js (str): json-encoded event body from event_json
- redactions (list[str]): a list of the events which claim to have redacted
- this event, from the redactions table
- format_version: (str): the 'format_version' column
- rejected_reason (str|None): the reason this event was rejected, if any
-
- Returns:
- _EventCacheEntry
- """
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if format_version is None:
- # This means that we stored the event before we had the concept
- # of a event format version, so it must be a V1 event.
- format_version = EventFormatVersions.V1
-
- original_ev = event_type_from_format_version(format_version)(
- event_dict=d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
-
- cache_entry = _EventCacheEntry(
- event=original_ev, redacted_event=redacted_event
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- return cache_entry
-
- @defer.inlineCallbacks
- def _maybe_redact_event_row(self, original_ev, redactions):
+ def _maybe_redact_event_row(self, original_ev, redactions, event_map):
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
@@ -631,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore):
Args:
original_ev (EventBase):
redactions (iterable[str]): list of event ids of potential redaction events
+ event_map (dict[str, EventBase]): other events which have been fetched, in
+ which we can look up the redaaction events. Map from event id to event.
Returns:
Deferred[EventBase|None]: if the event should be redacted, a pruned
@@ -640,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore):
# we choose to ignore redactions of m.room.create events.
return None
- if original_ev.type == "m.room.redaction":
- # ... and redaction events
- return None
-
- redaction_map = yield self._get_events_from_cache_or_db(redactions)
-
for redaction_id in redactions:
- redaction_entry = redaction_map.get(redaction_id)
- if not redaction_entry:
+ redaction_event = event_map.get(redaction_id)
+ if not redaction_event or redaction_event.rejected_reason:
# we don't have the redaction event, or the redaction event was not
# authorized.
logger.debug(
@@ -658,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore):
)
continue
- redaction_event = redaction_entry.event
if redaction_event.room_id != original_ev.room_id:
logger.debug(
"%s was redacted by %s but redaction was in a different room!",
|