diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8e88784d3c..a347430aa7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -46,16 +46,19 @@ from typing import (
Set,
Tuple,
cast,
+ overload,
)
import attr
from frozendict import frozendict
+from typing_extensions import Literal
from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -795,6 +798,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return RoomStreamToken(topo, stream_ordering)
+ @overload
+ def get_stream_id_for_event_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none: Literal[False] = False,
+ ) -> int:
+ ...
+
+ @overload
+ def get_stream_id_for_event_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ allow_none: bool = False,
+ ) -> Optional[int]:
+ ...
+
def get_stream_id_for_event_txn(
self,
txn: LoggingTransaction,
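The two overloads above exist purely for the type checker: a call that leaves
`allow_none` at its `Literal[False]` default resolves to the first overload and
yields a plain `int`, while an explicit `allow_none=True` matches the broader
`bool` overload and yields `Optional[int]`. A minimal sketch of the narrowing
at a call site (the `store` and `txn` values are illustrative, not part of this
change):

    def _example(store: StreamWorkerStore, txn: LoggingTransaction) -> None:
        # First overload: allow_none defaults to Literal[False], so mypy
        # infers a plain int and no None check is needed.
        stream_id: int = store.get_stream_id_for_event_txn(txn, "$ev:example.org")

        # Second overload: allow_none=True widens the return type to
        # Optional[int], forcing callers to handle the missing-event case.
        maybe_id = store.get_stream_id_for_event_txn(
            txn, "$ev:example.org", allow_none=True
        )
        if maybe_id is not None:
            stream_id = maybe_id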
@@ -1002,8 +1023,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
}
async def get_all_new_events_stream(
- self, from_id: int, current_id: int, limit: int
- ) -> Tuple[int, List[EventBase]]:
+ self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+ ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
@@ -1012,19 +1033,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
+ get_prev_content: if True, also fetch the previous content of state events (exposed via each event's unsigned field)
Returns:
- A tuple of (next_id, events), where `next_id` is the next value to
- pass as `from_id` (it will either be the stream_ordering of the
- last returned event, or, if fewer than `limit` events were found,
- the `current_id`).
+ A tuple of (next_id, events, event_to_received_ts), where `next_id`
+ is the next value to pass as `from_id` (it will either be the
+ stream_ordering of the last returned event, or, if fewer than `limit`
+ events were found, the `current_id`). `event_to_received_ts` maps each
+ returned event ID to that event's `received_ts` timestamp, which may be None.
"""
def get_all_new_events_stream_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[str]]:
+ ) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
- "SELECT e.stream_ordering, e.event_id"
+ "SELECT e.stream_ordering, e.event_id, e.received_ts"
" FROM events AS e"
" WHERE"
" ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1039,15 +1062,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if len(rows) == limit:
upper_bound = rows[-1][0]
- return upper_bound, [row[1] for row in rows]
+ event_to_received_ts: Dict[str, Optional[int]] = {
+ row[1]: row[2] for row in rows
+ }
+ return upper_bound, event_to_received_ts
- upper_bound, event_ids = await self.db_pool.runInteraction(
+ upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = await self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(
+ event_to_received_ts.keys(),
+ get_prev_content=get_prev_content,
+ )
- return upper_bound, events
+ return upper_bound, events, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
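With the widened return value, a caller gets the `received_ts` map in the same
round trip as the events themselves. A hedged sketch of a consumer loop under
these assumptions (`_handle` is a hypothetical callback, not part of this diff):

    async def _drain_new_events(
        store: StreamWorkerStore, from_id: int, current_id: int
    ) -> int:
        upper_bound, events, event_to_received_ts = (
            await store.get_all_new_events_stream(
                from_id, current_id, limit=100, get_prev_content=True
            )
        )
        for event in events:
            # The map is typed Dict[str, Optional[int]] in this diff, so a
            # None timestamp must be tolerated.
            _handle(event, event_to_received_ts[event.event_id])
        # Feed upper_bound back in as from_id next time; it equals current_id
        # once fewer than `limit` events remain.
        return upper_bound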
@@ -1318,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, next_token
+ @trace
async def paginate_room_events(
self,
room_id: str,
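`@trace` from `synapse.logging.opentracing` runs each call to
`paginate_room_events` inside an OpenTracing span named after the decorated
function. A rough illustration of the effect (simplified; not the real
decorator internals):

    from synapse.logging.opentracing import trace

    @trace
    async def paginate_example(room_id: str) -> None:
        # Executes inside a span named "paginate_example", making
        # pagination latency visible in collected traces.
        ...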