diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 445213e12a..3151186e0c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
-from typing import Any, Dict, List, Tuple
+from typing import List, Optional, Tuple
+
+import attr
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
@@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateDelta:
+ stream_id: int
+ room_id: str
+ event_type: str
+ state_key: str
+
+ event_id: Optional[str]
+ """new event_id for this state key. None if the state has been deleted."""
+
+ prev_event_id: Optional[str]
+ """previous event_id for this state key. None if it's new state."""
+
+
class StateDeltasStore(SQLBaseStore):
# This class must be mixed in with a child class which provides the following
# attribute. TODO: can we get static analysis to enforce this?
@@ -29,31 +45,21 @@ class StateDeltasStore(SQLBaseStore):
async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
"""Fetch a list of room state changes since the given stream id
- Each entry in the result contains the following fields:
- - stream_id (int)
- - room_id (str)
- - type (str): event type
- - state_key (str):
- - event_id (str|None): new event_id for this state key. None if the
- state has been deleted.
- - prev_event_id (str|None): previous event_id for this state key. None
- if it's new state.
-
This may be the partial state if we're lazy joining the room.
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- - ie, an upper limit to return changes from.
+ - ie, an upper limit to return changes from.
Returns:
A tuple consisting of:
- - the stream id which these results go up to
- - list of current_state_delta_stream rows. If it is empty, we are
- up to date.
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
@@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, List[Dict[str, Any]]]:
+ ) -> Tuple[int, List[StateDelta]]:
# First we calculate the max stream id that will give us less than
# N results.
# We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
+ return clipped_stream_id, [
+ StateDelta(
+ stream_id=row[0],
+ room_id=row[1],
+ event_type=row[2],
+ state_key=row[3],
+ event_id=row[4],
+ prev_event_id=row[5],
+ )
+ for row in txn.fetchall()
+ ]
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
|