summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10482.misc1
-rw-r--r--synapse/state/__init__.py26
-rw-r--r--synapse/storage/databases/state/store.py17
-rw-r--r--synapse/storage/state.py4
4 files changed, 29 insertions, 19 deletions
diff --git a/changelog.d/10482.misc b/changelog.d/10482.misc
new file mode 100644
index 0000000000..4e9e2126e1
--- /dev/null
+++ b/changelog.d/10482.misc
@@ -0,0 +1 @@
+Additional type hints in the state handler.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2e15471435..463ce58dae 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@ import heapq
 import logging
 from collections import defaultdict, namedtuple
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -52,6 +53,10 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure, measure_func
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
 metrics_logger = logging.getLogger("synapse.state.metrics")
 
@@ -74,7 +79,7 @@ _NEXT_STATE_ID = 1
 POWER_KEY = (EventTypes.PowerLevels, "")
 
 
-def _gen_state_id():
+def _gen_state_id() -> str:
     global _NEXT_STATE_ID
     s = "X%d" % (_NEXT_STATE_ID,)
     _NEXT_STATE_ID += 1
@@ -109,7 +114,7 @@ class _StateCacheEntry:
         # `state_id` is either a state_group (and so an int) or a string. This
         # ensures we don't accidentally persist a state_id as a stateg_group
         if state_group:
-            self.state_id = state_group
+            self.state_id: Union[str, int] = state_group
         else:
             self.state_id = _gen_state_id()
 
@@ -122,7 +127,7 @@ class StateHandler:
     where necessary
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state_store = hs.get_storage().state
@@ -507,7 +512,7 @@ class StateResolutionHandler:
     be storage-independent.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
 
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
@@ -657,13 +662,15 @@ class StateResolutionHandler:
         finally:
             self._record_state_res_metrics(room_id, m.get_resource_usage())
 
-    def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage):
+    def _record_state_res_metrics(
+        self, room_id: str, rusage: ContextResourceUsage
+    ) -> None:
         room_metrics = self._state_res_metrics[room_id]
         room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime
         room_metrics.db_time += rusage.db_txn_duration_sec
         room_metrics.db_events += rusage.evt_db_fetch_count
 
-    def _report_metrics(self):
+    def _report_metrics(self) -> None:
         if not self._state_res_metrics:
             # no state res has happened since the last iteration: don't bother logging.
             return
@@ -773,16 +780,13 @@ def _make_state_cache_entry(
     )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class StateResolutionStore:
     """Interface that allows state resolution algorithms to access the database
     in well defined way.
-
-    Args:
-        store (DataStore)
     """
 
-    store = attr.ib()
+    store: "DataStore"
 
     def get_events(
         self, event_ids: Iterable[str], allow_rejected: bool = False
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e38461adbc..f839c0c24f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -372,18 +372,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
     async def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+        self,
+        event_id: str,
+        room_id: str,
+        prev_group: Optional[int],
+        delta_ids: Optional[StateMap[str]],
+        current_state_ids: StateMap[str],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
-            event_id (str): The event ID for which the state was calculated
-            room_id (str)
-            prev_group (int|None): A previous state group for the room, optional.
-            delta_ids (dict|None): The delta between state at `prev_group` and
+            event_id: The event ID for which the state was calculated
+            room_id
+            prev_group: A previous state group for the room, optional.
+            delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
-            current_state_ids (dict): The state to store. Map of (type, state_key)
+            current_state_ids: The state to store. Map of (type, state_key)
                 to event_id.
 
         Returns:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f8fbba9d38..e5400d681a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -570,8 +570,8 @@ class StateGroupStorage:
         event_id: str,
         room_id: str,
         prev_group: Optional[int],
-        delta_ids: Optional[dict],
-        current_state_ids: dict,
+        delta_ids: Optional[StateMap[str]],
+        current_state_ids: StateMap[str],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.