summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-16 07:35:22 -0400
committerGitHub <noreply@github.com>2023-10-16 07:35:22 -0400
commite3e0ae4ab1f48974ca66a4c4e6be8019aaa38fd1 (patch)
treefdc93c73b1d80f27454c29541dd90a8b704596dd /synapse/storage
parentBump pillow from 10.0.1 to 10.1.0 (#16498) (diff)
downloadsynapse-e3e0ae4ab1f48974ca66a4c4e6be8019aaa38fd1.tar.xz
Convert state delta processing from a dict to attrs. (#16469)
For improved type checking & memory usage.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/state.py14
-rw-r--r--synapse/storage/databases/main/state_deltas.py52
2 files changed, 36 insertions, 30 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 46957723a1..9f7959c45d 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -16,7 +16,6 @@ from itertools import chain
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
-    Any,
     Callable,
     Collection,
     Dict,
@@ -32,6 +31,7 @@ from typing import (
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.storage.roommember import ProfileInfo
 from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
@@ -531,19 +531,9 @@ class StateStorageController:
     @tag_args
     async def get_current_state_deltas(
         self, prev_stream_id: int, max_stream_id: int
-    ) -> Tuple[int, List[Dict[str, Any]]]:
+    ) -> Tuple[int, List[StateDelta]]:
         """Fetch a list of room state changes since the given stream id
 
-        Each entry in the result contains the following fields:
-            - stream_id (int)
-            - room_id (str)
-            - type (str): event type
-            - state_key (str):
-            - event_id (str|None): new event_id for this state key. None if the
-                state has been deleted.
-            - prev_event_id (str|None): previous event_id for this state key. None
-                if it's new state.
-
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 445213e12a..3151186e0c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Tuple
+from typing import List, Optional, Tuple
+
+import attr
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction
@@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateDelta:
+    stream_id: int
+    room_id: str
+    event_type: str
+    state_key: str
+
+    event_id: Optional[str]
+    """new event_id for this state key. None if the state has been deleted."""
+
+    prev_event_id: Optional[str]
+    """previous event_id for this state key. None if it's new state."""
+
+
 class StateDeltasStore(SQLBaseStore):
     # This class must be mixed in with a child class which provides the following
     # attribute. TODO: can we get static analysis to enforce this?
@@ -29,31 +45,21 @@ class StateDeltasStore(SQLBaseStore):
 
     async def get_partial_current_state_deltas(
         self, prev_stream_id: int, max_stream_id: int
-    ) -> Tuple[int, List[Dict[str, Any]]]:
+    ) -> Tuple[int, List[StateDelta]]:
         """Fetch a list of room state changes since the given stream id
 
-        Each entry in the result contains the following fields:
-            - stream_id (int)
-            - room_id (str)
-            - type (str): event type
-            - state_key (str):
-            - event_id (str|None): new event_id for this state key. None if the
-                state has been deleted.
-            - prev_event_id (str|None): previous event_id for this state key. None
-                if it's new state.
-
         This may be the partial state if we're lazy joining the room.
 
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
-               - ie, an upper limit to return changes from.
+                - ie, an upper limit to return changes from.
 
         Returns:
             A tuple consisting of:
-               - the stream id which these results go up to
-               - list of current_state_delta_stream rows. If it is empty, we are
-                 up to date.
+                - the stream id which these results go up to
+                - list of current_state_delta_stream rows. If it is empty, we are
+                  up to date.
         """
         prev_stream_id = int(prev_stream_id)
 
@@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
 
         def get_current_state_deltas_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[Dict[str, Any]]]:
+        ) -> Tuple[int, List[StateDelta]]:
             # First we calculate the max stream id that will give us less than
             # N results.
             # We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
                 ORDER BY stream_id ASC
             """
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
-            return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
+            return clipped_stream_id, [
+                StateDelta(
+                    stream_id=row[0],
+                    room_id=row[1],
+                    event_type=row[2],
+                    state_key=row[3],
+                    event_id=row[4],
+                    prev_event_id=row[5],
+                )
+                for row in txn.fetchall()
+            ]
 
         return await self.db_pool.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn