summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/events.py59
1 files changed, 35 insertions, 24 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 4aa4e7ab15..656e57b5c6 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,6 @@ from twisted.internet import defer, reactor
 
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
-from synapse.util import unwrap_deferred
 
 from synapse.util.logcontext import preserve_context_over_deferred
 from synapse.util.logutils import log_function
@@ -401,11 +400,7 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, allow_rejected=False, txn=None):
-        """Gets a collection of events. If `txn` is not None the we use the
-        current transaction to fetch events and we return a deferred that is
-        guarenteed to have resolved.
-        """
+                    get_prev_content=False, allow_rejected=False):
         if not event_ids:
             defer.returnValue([])
 
@@ -424,21 +419,12 @@ class EventsStore(SQLBaseStore):
                 if e_id in event_map and event_map[e_id]
             ])
 
-        if not txn:
-            missing_events = yield self._enqueue_events(
-                missing_events_ids,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content,
-                allow_rejected=allow_rejected,
-            )
-        else:
-            missing_events = self._fetch_events_txn(
-                txn,
-                missing_events_ids,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content,
-                allow_rejected=allow_rejected,
-            )
+        missing_events = yield self._enqueue_events(
+            missing_events_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
 
         event_map.update(missing_events)
 
@@ -449,13 +435,38 @@ class EventsStore(SQLBaseStore):
 
     def _get_events_txn(self, txn, event_ids, check_redacted=True,
                         get_prev_content=False, allow_rejected=False):
-        return unwrap_deferred(self._get_events(
+        if not event_ids:
+            return []
+
+        event_map = self._get_events_from_cache(
             event_ids,
             check_redacted=check_redacted,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
-            txn=txn,
-        ))
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_map]
+
+        if not missing_events_ids:
+            return [
+                event_map[e_id] for e_id in event_ids
+                if e_id in event_map and event_map[e_id]
+            ]
+
+        missing_events = self._fetch_events_txn(
+            txn,
+            missing_events_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        event_map.update(missing_events)
+
+        return [
+            event_map[e_id] for e_id in event_ids
+            if e_id in event_map and event_map[e_id]
+        ]
 
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):