summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-16 07:35:22 -0400
committerGitHub <noreply@github.com>2023-10-16 07:35:22 -0400
commite3e0ae4ab1f48974ca66a4c4e6be8019aaa38fd1 (patch)
treefdc93c73b1d80f27454c29541dd90a8b704596dd
parentBump pillow from 10.0.1 to 10.1.0 (#16498) (diff)
downloadsynapse-e3e0ae4ab1f48974ca66a4c4e6be8019aaa38fd1.tar.xz
Convert state delta processing from a dict to attrs. (#16469)
For improved type checking & memory usage.
-rw-r--r--changelog.d/16469.misc1
-rw-r--r--synapse/handlers/presence.py32
-rw-r--r--synapse/handlers/room_member.py21
-rw-r--r--synapse/handlers/stats.py64
-rw-r--r--synapse/handlers/user_directory.py34
-rw-r--r--synapse/storage/controllers/state.py14
-rw-r--r--synapse/storage/databases/main/state_deltas.py52
-rw-r--r--tests/handlers/test_typing.py2
8 files changed, 111 insertions, 109 deletions
diff --git a/changelog.d/16469.misc b/changelog.d/16469.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16469.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 7c7cda3e95..dfc0b9db07 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -110,6 +110,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
@@ -1499,9 +1500,9 @@ class PresenceHandler(BasePresenceHandler):
                 # We may get multiple deltas for different rooms, but we want to
                 # handle them on a room by room basis, so we batch them up by
                 # room.
-                deltas_by_room: Dict[str, List[JsonDict]] = {}
+                deltas_by_room: Dict[str, List[StateDelta]] = {}
                 for delta in deltas:
-                    deltas_by_room.setdefault(delta["room_id"], []).append(delta)
+                    deltas_by_room.setdefault(delta.room_id, []).append(delta)
 
                 for room_id, deltas_for_room in deltas_by_room.items():
                     await self._handle_state_delta(room_id, deltas_for_room)
@@ -1513,7 +1514,7 @@ class PresenceHandler(BasePresenceHandler):
                     max_pos
                 )
 
-    async def _handle_state_delta(self, room_id: str, deltas: List[JsonDict]) -> None:
+    async def _handle_state_delta(self, room_id: str, deltas: List[StateDelta]) -> None:
         """Process current state deltas for the room to find new joins that need
         to be handled.
         """
@@ -1524,31 +1525,30 @@ class PresenceHandler(BasePresenceHandler):
         newly_joined_users = set()
 
         for delta in deltas:
-            assert room_id == delta["room_id"]
+            assert room_id == delta.room_id
 
-            typ = delta["type"]
-            state_key = delta["state_key"]
-            event_id = delta["event_id"]
-            prev_event_id = delta["prev_event_id"]
-
-            logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+            logger.debug(
+                "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
+            )
 
             # Drop any event that isn't a membership join
-            if typ != EventTypes.Member:
+            if delta.event_type != EventTypes.Member:
                 continue
 
-            if event_id is None:
+            if delta.event_id is None:
                 # state has been deleted, so this is not a join. We only care about
                 # joins.
                 continue
 
-            event = await self.store.get_event(event_id, allow_none=True)
+            event = await self.store.get_event(delta.event_id, allow_none=True)
             if not event or event.content.get("membership") != Membership.JOIN:
                 # We only care about joins
                 continue
 
-            if prev_event_id:
-                prev_event = await self.store.get_event(prev_event_id, allow_none=True)
+            if delta.prev_event_id:
+                prev_event = await self.store.get_event(
+                    delta.prev_event_id, allow_none=True
+                )
                 if (
                     prev_event
                     and prev_event.content.get("membership") == Membership.JOIN
@@ -1556,7 +1556,7 @@ class PresenceHandler(BasePresenceHandler):
                     # Ignore changes to join events.
                     continue
 
-            newly_joined_users.add(state_key)
+            newly_joined_users.add(delta.state_key)
 
         if not newly_joined_users:
             # If nobody has joined then there's nothing to do.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 130eee7e1d..918eb203e2 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
 
 from synapse import types
 from synapse.api.constants import (
@@ -44,6 +44,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
     JsonDict,
     Requester,
@@ -2146,24 +2147,18 @@ class RoomForgetterHandler(StateDeltasHandler):
 
             await self._store.update_room_forgetter_stream_pos(max_pos)
 
-    async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
+    async def _handle_deltas(self, deltas: List[StateDelta]) -> None:
         """Called with the state deltas to process"""
         for delta in deltas:
-            typ = delta["type"]
-            state_key = delta["state_key"]
-            room_id = delta["room_id"]
-            event_id = delta["event_id"]
-            prev_event_id = delta["prev_event_id"]
-
-            if typ != EventTypes.Member:
+            if delta.event_type != EventTypes.Member:
                 continue
 
-            if not self._hs.is_mine_id(state_key):
+            if not self._hs.is_mine_id(delta.state_key):
                 continue
 
             change = await self._get_key_change(
-                prev_event_id,
-                event_id,
+                delta.prev_event_id,
+                delta.event_id,
                 key_name="membership",
                 public_value=Membership.JOIN,
             )
@@ -2172,7 +2167,7 @@ class RoomForgetterHandler(StateDeltasHandler):
             if is_leave:
                 try:
                     await self._room_member_handler.forget(
-                        UserID.from_string(state_key), room_id
+                        UserID.from_string(delta.state_key), delta.room_id
                     )
                 except SynapseError as e:
                     if e.code == 400:
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 3dde19fc81..817b41aa37 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -27,6 +27,7 @@ from typing import (
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -142,7 +143,7 @@ class StatsHandler:
             self.pos = max_pos
 
     async def _handle_deltas(
-        self, deltas: Iterable[JsonDict]
+        self, deltas: Iterable[StateDelta]
     ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process
 
@@ -157,51 +158,50 @@ class StatsHandler:
         room_to_state_updates: Dict[str, Dict[str, Any]] = {}
 
         for delta in deltas:
-            typ = delta["type"]
-            state_key = delta["state_key"]
-            room_id = delta["room_id"]
-            event_id = delta["event_id"]
-            stream_id = delta["stream_id"]
-            prev_event_id = delta["prev_event_id"]
-
-            logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id)
+            logger.debug(
+                "Handling: %r, %r %r, %s",
+                delta.room_id,
+                delta.event_type,
+                delta.state_key,
+                delta.event_id,
+            )
 
-            token = await self.store.get_earliest_token_for_stats("room", room_id)
+            token = await self.store.get_earliest_token_for_stats("room", delta.room_id)
 
             # If the earliest token to begin from is larger than our current
             # stream ID, skip processing this delta.
-            if token is not None and token >= stream_id:
+            if token is not None and token >= delta.stream_id:
                 logger.debug(
                     "Ignoring: %s as earlier than this room's initial ingestion event",
-                    event_id,
+                    delta.event_id,
                 )
                 continue
 
-            if event_id is None and prev_event_id is None:
+            if delta.event_id is None and delta.prev_event_id is None:
                 logger.error(
                     "event ID is None and so is the previous event ID. stream_id: %s",
-                    stream_id,
+                    delta.stream_id,
                 )
                 continue
 
             event_content: JsonDict = {}
 
-            if event_id is not None:
-                event = await self.store.get_event(event_id, allow_none=True)
+            if delta.event_id is not None:
+                event = await self.store.get_event(delta.event_id, allow_none=True)
                 if event:
                     event_content = event.content or {}
 
             # All the values in this dict are deltas (RELATIVE changes)
-            room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
+            room_stats_delta = room_to_stats_deltas.setdefault(delta.room_id, Counter())
 
-            room_state = room_to_state_updates.setdefault(room_id, {})
+            room_state = room_to_state_updates.setdefault(delta.room_id, {})
 
-            if prev_event_id is None:
+            if delta.prev_event_id is None:
                 # this state event doesn't overwrite another,
                 # so it is a new effective/current state event
                 room_stats_delta["current_state_events"] += 1
 
-            if typ == EventTypes.Member:
+            if delta.event_type == EventTypes.Member:
                 # we could use StateDeltasHandler._get_key_change here but it's
                 # a bit inefficient given we're not testing for a specific
                 # result; might as well just grab the prev_membership and
@@ -210,9 +210,9 @@ class StatsHandler:
                 # in the absence of a previous event because we do not want to
                 # reduce the leave count when a new-to-the-room user joins.
                 prev_membership = None
-                if prev_event_id is not None:
+                if delta.prev_event_id is not None:
                     prev_event = await self.store.get_event(
-                        prev_event_id, allow_none=True
+                        delta.prev_event_id, allow_none=True
                     )
                     if prev_event:
                         prev_event_content = prev_event.content
@@ -256,7 +256,7 @@ class StatsHandler:
                 else:
                     raise ValueError("%r is not a valid membership" % (membership,))
 
-                user_id = state_key
+                user_id = delta.state_key
                 if self.is_mine_id(user_id):
                     # this accounts for transitions like leave → ban and so on.
                     has_changed_joinedness = (prev_membership == Membership.JOIN) != (
@@ -272,30 +272,30 @@ class StatsHandler:
 
                         room_stats_delta["local_users_in_room"] += membership_delta
 
-            elif typ == EventTypes.Create:
+            elif delta.event_type == EventTypes.Create:
                 room_state["is_federatable"] = (
                     event_content.get(EventContentFields.FEDERATE, True) is True
                 )
                 room_type = event_content.get(EventContentFields.ROOM_TYPE)
                 if isinstance(room_type, str):
                     room_state["room_type"] = room_type
-            elif typ == EventTypes.JoinRules:
+            elif delta.event_type == EventTypes.JoinRules:
                 room_state["join_rules"] = event_content.get("join_rule")
-            elif typ == EventTypes.RoomHistoryVisibility:
+            elif delta.event_type == EventTypes.RoomHistoryVisibility:
                 room_state["history_visibility"] = event_content.get(
                     "history_visibility"
                 )
-            elif typ == EventTypes.RoomEncryption:
+            elif delta.event_type == EventTypes.RoomEncryption:
                 room_state["encryption"] = event_content.get("algorithm")
-            elif typ == EventTypes.Name:
+            elif delta.event_type == EventTypes.Name:
                 room_state["name"] = event_content.get("name")
-            elif typ == EventTypes.Topic:
+            elif delta.event_type == EventTypes.Topic:
                 room_state["topic"] = event_content.get("topic")
-            elif typ == EventTypes.RoomAvatar:
+            elif delta.event_type == EventTypes.RoomAvatar:
                 room_state["avatar"] = event_content.get("url")
-            elif typ == EventTypes.CanonicalAlias:
+            elif delta.event_type == EventTypes.CanonicalAlias:
                 room_state["canonical_alias"] = event_content.get("alias")
-            elif typ == EventTypes.GuestAccess:
+            elif delta.event_type == EventTypes.GuestAccess:
                 room_state["guest_access"] = event_content.get(
                     EventContentFields.GUEST_ACCESS
                 )
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index a0f5568000..75717ba4f9 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -14,7 +14,7 @@
 
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple
 
 from twisted.internet.interfaces import IDelayedCall
 
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Memb
 from synapse.api.errors import Codes, SynapseError
 from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.storage.databases.main.user_directory import SearchResult
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import UserID
@@ -247,32 +248,31 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                 await self.store.update_user_directory_stream_pos(max_pos)
 
-    async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
+    async def _handle_deltas(self, deltas: List[StateDelta]) -> None:
         """Called with the state deltas to process"""
         for delta in deltas:
-            typ = delta["type"]
-            state_key = delta["state_key"]
-            room_id = delta["room_id"]
-            event_id: Optional[str] = delta["event_id"]
-            prev_event_id: Optional[str] = delta["prev_event_id"]
-
-            logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+            logger.debug(
+                "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
+            )
 
             # For join rule and visibility changes we need to check if the room
             # may have become public or not and add/remove the users in said room
-            if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
+            if delta.event_type in (
+                EventTypes.RoomHistoryVisibility,
+                EventTypes.JoinRules,
+            ):
                 await self._handle_room_publicity_change(
-                    room_id, prev_event_id, event_id, typ
+                    delta.room_id, delta.prev_event_id, delta.event_id, delta.event_type
                 )
-            elif typ == EventTypes.Member:
+            elif delta.event_type == EventTypes.Member:
                 await self._handle_room_membership_event(
-                    room_id,
-                    prev_event_id,
-                    event_id,
-                    state_key,
+                    delta.room_id,
+                    delta.prev_event_id,
+                    delta.event_id,
+                    delta.state_key,
                 )
             else:
-                logger.debug("Ignoring irrelevant type: %r", typ)
+                logger.debug("Ignoring irrelevant type: %r", delta.event_type)
 
     async def _handle_room_publicity_change(
         self,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 46957723a1..9f7959c45d 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -16,7 +16,6 @@ from itertools import chain
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
-    Any,
     Callable,
     Collection,
     Dict,
@@ -32,6 +31,7 @@ from typing import (
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.storage.roommember import ProfileInfo
 from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
@@ -531,19 +531,9 @@ class StateStorageController:
     @tag_args
     async def get_current_state_deltas(
         self, prev_stream_id: int, max_stream_id: int
-    ) -> Tuple[int, List[Dict[str, Any]]]:
+    ) -> Tuple[int, List[StateDelta]]:
         """Fetch a list of room state changes since the given stream id
 
-        Each entry in the result contains the following fields:
-            - stream_id (int)
-            - room_id (str)
-            - type (str): event type
-            - state_key (str):
-            - event_id (str|None): new event_id for this state key. None if the
-                state has been deleted.
-            - prev_event_id (str|None): previous event_id for this state key. None
-                if it's new state.
-
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 445213e12a..3151186e0c 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Tuple
+from typing import List, Optional, Tuple
+
+import attr
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction
@@ -22,6 +24,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class StateDelta:
+    stream_id: int
+    room_id: str
+    event_type: str
+    state_key: str
+
+    event_id: Optional[str]
+    """new event_id for this state key. None if the state has been deleted."""
+
+    prev_event_id: Optional[str]
+    """previous event_id for this state key. None if it's new state."""
+
+
 class StateDeltasStore(SQLBaseStore):
     # This class must be mixed in with a child class which provides the following
     # attribute. TODO: can we get static analysis to enforce this?
@@ -29,31 +45,21 @@ class StateDeltasStore(SQLBaseStore):
 
     async def get_partial_current_state_deltas(
         self, prev_stream_id: int, max_stream_id: int
-    ) -> Tuple[int, List[Dict[str, Any]]]:
+    ) -> Tuple[int, List[StateDelta]]:
         """Fetch a list of room state changes since the given stream id
 
-        Each entry in the result contains the following fields:
-            - stream_id (int)
-            - room_id (str)
-            - type (str): event type
-            - state_key (str):
-            - event_id (str|None): new event_id for this state key. None if the
-                state has been deleted.
-            - prev_event_id (str|None): previous event_id for this state key. None
-                if it's new state.
-
         This may be the partial state if we're lazy joining the room.
 
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
-               - ie, an upper limit to return changes from.
+                - ie, an upper limit to return changes from.
 
         Returns:
             A tuple consisting of:
-               - the stream id which these results go up to
-               - list of current_state_delta_stream rows. If it is empty, we are
-                 up to date.
+                - the stream id which these results go up to
+                - list of current_state_delta_stream rows. If it is empty, we are
+                  up to date.
         """
         prev_stream_id = int(prev_stream_id)
 
@@ -72,7 +78,7 @@ class StateDeltasStore(SQLBaseStore):
 
         def get_current_state_deltas_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[Dict[str, Any]]]:
+        ) -> Tuple[int, List[StateDelta]]:
             # First we calculate the max stream id that will give us less than
             # N results.
             # We arbitrarily limit to 100 stream_id entries to ensure we don't
@@ -112,7 +118,17 @@ class StateDeltasStore(SQLBaseStore):
                 ORDER BY stream_id ASC
             """
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
-            return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
+            return clipped_stream_id, [
+                StateDelta(
+                    stream_id=row[0],
+                    room_id=row[1],
+                    event_type=row[2],
+                    state_key=row[3],
+                    event_id=row[4],
+                    prev_event_id=row[5],
+                )
+                for row in txn.fetchall()
+            ]
 
         return await self.db_pool.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 3060bc9744..d7025c6f2c 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -174,7 +174,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             return_value=1
         )
 
-        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))  # type: ignore[method-assign]
+        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, []))  # type: ignore[method-assign]
 
         self.datastore.get_to_device_stream_token = Mock(  # type: ignore[method-assign]
             return_value=0