diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4aa4e7ab15..656e57b5c6 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,6 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
-from synapse.util import unwrap_deferred
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function
@@ -401,11 +400,7 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False, allow_rejected=False, txn=None):
- """Gets a collection of events. If `txn` is not None the we use the
- current transaction to fetch events and we return a deferred that is
- guarenteed to have resolved.
- """
+ get_prev_content=False, allow_rejected=False):
if not event_ids:
defer.returnValue([])
@@ -424,21 +419,12 @@ class EventsStore(SQLBaseStore):
if e_id in event_map and event_map[e_id]
])
- if not txn:
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
- else:
- missing_events = self._fetch_events_txn(
- txn,
- missing_events_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
event_map.update(missing_events)
@@ -449,13 +435,38 @@ class EventsStore(SQLBaseStore):
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
- return unwrap_deferred(self._get_events(
+ if not event_ids:
+ return []
+
+ event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
- txn=txn,
- ))
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_map]
+
+ if not missing_events_ids:
+ return [
+ event_map[e_id] for e_id in event_ids
+ if e_id in event_map and event_map[e_id]
+ ]
+
+ missing_events = self._fetch_events_txn(
+ txn,
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ event_map.update(missing_events)
+
+ return [
+ event_map[e_id] for e_id in event_ids
+ if e_id in event_map and event_map[e_id]
+ ]
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
|