summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/appservice.py58
-rw-r--r--synapse/storage/databases/main/events_worker.py19
-rw-r--r--synapse/storage/databases/main/stream.py32
3 files changed, 38 insertions, 71 deletions
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore(
             device_list_summary=DeviceListUpdates(),
         )
 
-    async def set_appservice_last_pos(self, pos: int) -> None:
-        def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
-            )
+    async def get_appservice_last_pos(self) -> int:
+        """
+        Get the last stream ordering position for the appservice process.
+        """
 
-        await self.db_pool.runInteraction(
-            "set_appservice_last_pos", set_appservice_last_pos_txn
+        return await self.db_pool.simple_select_one_onecol(
+            table="appservice_stream_position",
+            retcol="stream_ordering",
+            keyvalues={},
+            desc="get_appservice_last_pos",
         )
 
-    async def get_new_events_for_appservice(
-        self, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
-        """Get all new events for an appservice"""
-
-        def get_new_events_for_appservice_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
-            sql = (
-                "SELECT e.stream_ordering, e.event_id"
-                " FROM events AS e"
-                " WHERE"
-                " (SELECT stream_ordering FROM appservice_stream_position)"
-                "     < e.stream_ordering"
-                " AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-                " LIMIT ?"
-            )
-
-            txn.execute(sql, (current_id, limit))
-            rows = txn.fetchall()
-
-            upper_bound = current_id
-            if len(rows) == limit:
-                upper_bound = rows[-1][0]
-
-            return upper_bound, [row[1] for row in rows]
+    async def set_appservice_last_pos(self, pos: int) -> None:
+        """
+        Set the last stream ordering position for the appservice process.
+        """
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
-            "get_new_events_for_appservice", get_new_events_for_appservice_txn
+        await self.db_pool.simple_update_one(
+            table="appservice_stream_position",
+            keyvalues={},
+            updatevalues={"stream_ordering": pos},
+            desc="set_appservice_last_pos",
         )
 
-        events = await self.get_events_as_list(event_ids, get_prev_content=True)
-
-        return upper_bound, events
-
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..621f92e238 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -292,25 +292,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def get_received_ts(self, event_id: str) -> Optional[int]:
-        """Get received_ts (when it was persisted) for the event.
-
-        Raises an exception for unknown events.
-
-        Args:
-            event_id: The event ID to query.
-
-        Returns:
-            Timestamp in milliseconds, or None for events that were persisted
-            before received_ts was implemented.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events",
-            keyvalues={"event_id": event_id},
-            retcol="received_ts",
-            desc="get_received_ts",
-        )
-
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3a1df7776c..2590b52f73 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1022,8 +1022,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     async def get_all_new_events_stream(
-        self, from_id: int, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
+        self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+    ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
         """Get all new events
 
         Returns all events with from_id < stream_ordering <= current_id.
@@ -1032,19 +1032,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_id:  the stream_ordering of the last event we processed
             current_id:  the stream_ordering of the most recently processed event
             limit: the maximum number of events to return
+            get_prev_content: whether to fetch previous event content
 
         Returns:
-            A tuple of (next_id, events), where `next_id` is the next value to
-            pass as `from_id` (it will either be the stream_ordering of the
-            last returned event, or, if fewer than `limit` events were found,
-            the `current_id`).
+            A tuple of (next_id, events, event_to_received_ts), where `next_id`
+            is the next value to pass as `from_id` (it will either be the
+            stream_ordering of the last returned event, or, if fewer than `limit`
+            events were found, the `current_id`). The `event_to_received_ts` is
+            a dictionary mapping event ID to the event `received_ts`.
         """
 
         def get_all_new_events_stream_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
+        ) -> Tuple[int, Dict[str, Optional[int]]]:
             sql = (
-                "SELECT e.stream_ordering, e.event_id"
+                "SELECT e.stream_ordering, e.event_id, e.received_ts"
                 " FROM events AS e"
                 " WHERE"
                 " ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1059,15 +1061,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if len(rows) == limit:
                 upper_bound = rows[-1][0]
 
-            return upper_bound, [row[1] for row in rows]
+            event_to_received_ts: Dict[str, Optional[int]] = {
+                row[1]: row[2] for row in rows
+            }
+            return upper_bound, event_to_received_ts
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
+        upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = await self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(
+            event_to_received_ts.keys(),
+            get_prev_content=get_prev_content,
+        )
 
-        return upper_bound, events
+        return upper_bound, events, event_to_received_ts
 
     async def get_federation_out_pos(self, typ: str) -> int:
         if self._need_to_reset_federation_stream_positions: