diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 79680ee856..5b606bec7c 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -17,6 +17,7 @@ from __future__ import division
import itertools
import logging
+import operator
from collections import namedtuple
from canonicaljson import json
@@ -421,28 +422,28 @@ class EventsWorkerStore(SQLBaseStore):
The fetch requests. Each entry consists of a list of event
ids to be fetched, and a deferred to be completed once the
events have been fetched.
+
+ The deferreds are callbacked with a dictionary mapping from event id
+ to event row. Note that it may well contain additional events that
+ were not part of this request.
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- event_id_lists = list(zip(*event_list))[0]
- event_ids = [item for sublist in event_id_lists for item in sublist]
+ events_to_fetch = set(
+ event_id for events, _ in event_list for event_id in events
+ )
row_dict = self._new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
+ conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
# We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([res[i] for i in ids if i in res])
- except Exception:
- logger.exception("Failed to callback")
+ def fire():
+ for _, d in event_list:
+ d.callback(row_dict)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
+ self.hs.get_reactor().callFromThread(fire)
except Exception as e:
logger.exception("do_fetch")
@@ -461,6 +462,12 @@ class EventsWorkerStore(SQLBaseStore):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
+
+ Args:
+ events (Iterable[str]): events to be fetched.
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]: map from event id to result.
"""
if not events:
return {}
@@ -484,11 +491,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+ row_map = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
+
+ rows = (row_map.get(event_id) for event_id in events)
+
+ # filter out absent rows
+ rows = filter(operator.truth, rows)
if not allow_rejected:
- rows[:] = [r for r in rows if r["rejected_reason"] is None]
+ rows = (r for r in rows if r["rejected_reason"] is None)
res = yield make_deferred_yieldable(
defer.gatherResults(
|