diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ce8be72bfe..73df6b33ba 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,7 +19,7 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple
from canonicaljson import json
from constantly import NamedConstant, Names
@@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
- def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+ async def get_all_updated_current_state_deltas(
+ self, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple], int, bool]:
+ """Fetch updates from current_state_delta_stream
+
+ Args:
+ from_token: The previous stream token. Updates from this stream id will
+ be excluded.
+
+ to_token: The current stream token (ie the upper limit). Updates up to this
+ stream id will be included (modulo the 'limit' param)
+
+ target_row_count: The number of rows to try to return. If more rows are
+ available, we will set 'limited' in the result. In the event of a large
+ batch, we may return more rows than this.
+ Returns:
+ A triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of database tuples.
+ * `new_last_token` is the new position in stream.
+ * `limited` is whether there are more updates to fetch.
+ """
+
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
@@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
- txn.execute(sql, (from_token, to_token, limit))
+ txn.execute(sql, (from_token, to_token, target_row_count))
return txn.fetchall()
- return self.db.runInteraction(
+ def get_deltas_for_stream_id_txn(txn, stream_id):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE stream_id = ?
+ """
+ txn.execute(sql, [stream_id])
+ return txn.fetchall()
+
+ # we need to make sure that, for every stream id in the results, we get *all*
+ # the rows with that stream id.
+
+ rows = await self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
+ ) # type: List[Tuple]
+
+ # if we've got fewer rows than the limit, we're good
+ if len(rows) < target_row_count:
+ return rows, to_token, False
+
+ # we hit the limit, so reduce the upper limit so that we exclude the stream id
+ # of the last row in the result.
+ assert rows[-1][0] <= to_token
+ to_token = rows[-1][0] - 1
+
+ # search backwards through the list for the point to truncate
+ for idx in range(len(rows) - 1, 0, -1):
+ if rows[idx - 1][0] <= to_token:
+ return rows[:idx], to_token, True
+
+ # bother. We didn't get a full set of changes for even a single
+ # stream id. let's run the query again, without a row limit, but for
+ # just one stream id.
+ to_token += 1
+ rows = await self.db.runInteraction(
+ "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
+ return rows, to_token, True
|