diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index eaa13da368..ba52fff652 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -161,45 +161,80 @@ class StateDeltasStore(SQLBaseStore):
self._get_max_stream_id_in_current_state_deltas_txn,
)
- @trace
- async def get_current_state_deltas_for_room(
- self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken
+ def get_current_state_deltas_for_room_txn(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ *,
+ from_token: Optional[RoomStreamToken],
+ to_token: Optional[RoomStreamToken],
) -> List[StateDelta]:
- """Get the state deltas between two tokens."""
-
- if not self._curr_state_delta_stream_cache.has_entity_changed(
- room_id, from_token.stream
- ):
- return []
+ """
+ Get the state deltas between two tokens.
- def get_current_state_deltas_for_room_txn(
- txn: LoggingTransaction,
- ) -> List[StateDelta]:
- sql = """
+ (> `from_token` and <= `to_token`)
+ """
+ from_clause = ""
+ from_args = []
+ if from_token is not None:
+ from_clause = "AND ? < stream_id"
+ from_args = [from_token.stream]
+
+ to_clause = ""
+ to_args = []
+ if to_token is not None:
+ to_clause = "AND stream_id <= ?"
+ to_args = [to_token.get_max_stream_pos()]
+
+ sql = f"""
SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id
FROM current_state_delta_stream
- WHERE room_id = ? AND ? < stream_id AND stream_id <= ?
+ WHERE room_id = ? {from_clause} {to_clause}
ORDER BY stream_id ASC
"""
- txn.execute(
- sql, (room_id, from_token.stream, to_token.get_max_stream_pos())
+ txn.execute(sql, [room_id] + from_args + to_args)
+
+ return [
+ StateDelta(
+ stream_id=row[1],
+ room_id=room_id,
+ event_type=row[2],
+ state_key=row[3],
+ event_id=row[4],
+ prev_event_id=row[5],
)
+ for row in txn
+ if _filter_results_by_stream(from_token, to_token, row[0], row[1])
+ ]
- return [
- StateDelta(
- stream_id=row[1],
- room_id=room_id,
- event_type=row[2],
- state_key=row[3],
- event_id=row[4],
- prev_event_id=row[5],
- )
- for row in txn
- if _filter_results_by_stream(from_token, to_token, row[0], row[1])
- ]
+ @trace
+ async def get_current_state_deltas_for_room(
+ self,
+ room_id: str,
+ *,
+ from_token: Optional[RoomStreamToken],
+ to_token: Optional[RoomStreamToken],
+ ) -> List[StateDelta]:
+ """
+ Get the state deltas between two tokens.
+
+ (> `from_token` and <= `to_token`)
+ """
+
+ if (
+ from_token is not None
+ and not self._curr_state_delta_stream_cache.has_entity_changed(
+ room_id, from_token.stream
+ )
+ ):
+ return []
return await self.db_pool.runInteraction(
- "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn
+ "get_current_state_deltas_for_room",
+ self.get_current_state_deltas_for_room_txn,
+ room_id,
+ from_token=from_token,
+ to_token=to_token,
)
@trace
|