summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py32
1 files changed, 20 insertions, 12 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3a1df7776c..2590b52f73 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1022,8 +1022,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     async def get_all_new_events_stream(
-        self, from_id: int, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
+        self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+    ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
         """Get all new events
 
         Returns all events with from_id < stream_ordering <= current_id.
@@ -1032,19 +1032,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_id:  the stream_ordering of the last event we processed
             current_id:  the stream_ordering of the most recently processed event
             limit: the maximum number of events to return
+            get_prev_content: whether to fetch previous event content
 
         Returns:
-            A tuple of (next_id, events), where `next_id` is the next value to
-            pass as `from_id` (it will either be the stream_ordering of the
-            last returned event, or, if fewer than `limit` events were found,
-            the `current_id`).
+            A tuple of (next_id, events, event_to_received_ts), where `next_id`
+            is the next value to pass as `from_id` (it will either be the
+            stream_ordering of the last returned event, or, if fewer than `limit`
+            events were found, the `current_id`). The `event_to_received_ts` is
+            a dictionary mapping event ID to the event `received_ts`.
         """
 
         def get_all_new_events_stream_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
+        ) -> Tuple[int, Dict[str, Optional[int]]]:
             sql = (
-                "SELECT e.stream_ordering, e.event_id"
+                "SELECT e.stream_ordering, e.event_id, e.received_ts"
                 " FROM events AS e"
                 " WHERE"
                 " ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1059,15 +1061,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if len(rows) == limit:
                 upper_bound = rows[-1][0]
 
-            return upper_bound, [row[1] for row in rows]
+            event_to_received_ts: Dict[str, Optional[int]] = {
+                row[1]: row[2] for row in rows
+            }
+            return upper_bound, event_to_received_ts
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
+        upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = await self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(
+            event_to_received_ts.keys(),
+            get_prev_content=get_prev_content,
+        )
 
-        return upper_bound, events
+        return upper_bound, events, event_to_received_ts
 
     async def get_federation_out_pos(self, typ: str) -> int:
         if self._need_to_reset_federation_stream_positions: