summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/appservice.py4
-rw-r--r--synapse/storage/event_federation.py6
-rw-r--r--synapse/storage/events_worker.py50
-rw-r--r--synapse/storage/search.py4
-rw-r--r--synapse/storage/stream.py18
-rw-r--r--tests/storage/test_appservice.py2
6 files changed, 54 insertions, 30 deletions
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 6092f600ba..eb329ebd8b 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore(
 
         event_ids = json.loads(entry["event_ids"])
 
-        events = yield self._get_events(event_ids)
+        events = yield self.get_events_as_list(event_ids)
 
         defer.returnValue(
             AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
             "get_new_events_for_appservice", get_new_events_for_appservice_txn
         )
 
-        events = yield self._get_events(event_ids)
+        events = yield self.get_events_as_list(event_ids)
 
         defer.returnValue((upper_bound, events))
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 956f876572..09e39c2c28 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
         return self.get_auth_chain_ids(
             event_ids, include_given=include_given
-        ).addCallback(self._get_events)
+        ).addCallback(self.get_events_as_list)
 
     def get_auth_chain_ids(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.
@@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 event_list,
                 limit,
             )
-            .addCallback(self._get_events)
+            .addCallback(self.get_events_as_list)
             .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
         )
 
@@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = yield self._get_events(ids)
+        events = yield self.get_events_as_list(ids)
         defer.returnValue(events)
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 663991a9b6..5c40056d11 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -103,7 +103,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred : A FrozenEvent.
         """
-        events = yield self._get_events(
+        events = yield self.get_events_as_list(
             [event_id],
             check_redacted=check_redacted,
             get_prev_content=get_prev_content,
@@ -142,7 +142,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred : Dict from event_id to event.
         """
-        events = yield self._get_events(
+        events = yield self.get_events_as_list(
             event_ids,
             check_redacted=check_redacted,
             get_prev_content=get_prev_content,
@@ -152,13 +152,32 @@ class EventsWorkerStore(SQLBaseStore):
         defer.returnValue({e.event_id: e for e in events})
 
     @defer.inlineCallbacks
-    def _get_events(
+    def get_events_as_list(
         self,
         event_ids,
         check_redacted=True,
         get_prev_content=False,
         allow_rejected=False,
     ):
+        """Get events from the database and return in a list in the same order
+        as given by `event_ids` arg.
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred[list]: List of events fetched from the database. The
+            events are in the same order as `event_ids` arg.
+
+            Note that the returned list may be smaller than the list of event
+            IDs if not all events could be fetched.
+        """
+
         if not event_ids:
             defer.returnValue([])
 
@@ -202,21 +221,22 @@ class EventsWorkerStore(SQLBaseStore):
                     #
                     # The problem is that we end up at this point when an event
                     # which has been redacted is pulled out of the database by
-                    # _enqueue_events, because _enqueue_events needs to check the
-                    # redaction before it can cache the redacted event. So obviously,
-                    # calling get_event to get the redacted event out of the database
-                    # gives us an infinite loop.
+                    # _enqueue_events, because _enqueue_events needs to check
+                    # the redaction before it can cache the redacted event. So
+                    # obviously, calling get_event to get the redacted event out
+                    # of the database gives us an infinite loop.
                     #
-                    # For now (quick hack to fix during 0.99 release cycle), we just
-                    # go and fetch the relevant row from the db, but it would be nice
-                    # to think about how we can cache this rather than hit the db
-                    # every time we access a redaction event.
+                    # For now (quick hack to fix during 0.99 release cycle), we
+                    # just go and fetch the relevant row from the db, but it
+                    # would be nice to think about how we can cache this rather
+                    # than hit the db every time we access a redaction event.
                     #
                     # One thought on how to do this:
-                    #  1. split _get_events up so that it is divided into (a) get the
-                    #     rawish event from the db/cache, (b) do the redaction/rejection
-                    #     filtering
-                    #  2. have _get_event_from_row just call the first half of that
+                    #  1. split get_events_as_list up so that it is divided into
+                    #     (a) get the rawish event from the db/cache, (b) do the
+                    #     redaction/rejection filtering
+                    #  2. have _get_event_from_row just call the first half of
+                    #     that
 
                     orig_sender = yield self._simple_select_one_onecol(
                         table="events",
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 226f8f1b7e..ff49eaae02 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self._get_events([r["event_id"] for r in results])
+        events = yield self.get_events_as_list([r["event_id"] for r in results])
 
         event_map = {ev.event_id: ev for ev in events}
 
@@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self._get_events([r["event_id"] for r in results])
+        events = yield self.get_events_as_list([r["event_id"] for r in results])
 
         event_map = {ev.event_id: ev for ev in events}
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 9cd1e0f9fe..d105b6b17d 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -319,7 +319,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows = yield self.runInteraction("get_room_events_stream_for_room", f)
 
-        ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
+        ret = yield self.get_events_as_list([
+            r.event_id for r in rows], get_prev_content=True,
+        )
 
         self._set_before_and_after(ret, rows, topo_order=from_id is None)
 
@@ -367,7 +369,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows = yield self.runInteraction("get_membership_changes_for_user", f)
 
-        ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
+        ret = yield self.get_events_as_list(
+            [r.event_id for r in rows], get_prev_content=True,
+        )
 
         self._set_before_and_after(ret, rows, topo_order=False)
 
@@ -394,7 +398,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         logger.debug("stream before")
-        events = yield self._get_events(
+        events = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
         logger.debug("stream after")
@@ -580,11 +584,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_filter,
         )
 
-        events_before = yield self._get_events(
+        events_before = yield self.get_events_as_list(
             [e for e in results["before"]["event_ids"]], get_prev_content=True
         )
 
-        events_after = yield self._get_events(
+        events_after = yield self.get_events_as_list(
             [e for e in results["after"]["event_ids"]], get_prev_content=True
         )
 
@@ -697,7 +701,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = yield self._get_events(event_ids)
+        events = yield self.get_events_as_list(event_ids)
 
         defer.returnValue((upper_bound, events))
 
@@ -849,7 +853,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_filter,
         )
 
-        events = yield self._get_events(
+        events = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 3f0083831b..25a6c89ef5 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -340,7 +340,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
 
         # we aren't testing store._base stuff here, so mock this out
-        self.store._get_events = Mock(return_value=events)
+        self.store.get_events_as_list = Mock(return_value=events)
 
         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)