summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-07-26 12:49:53 -0400
committerGitHub <noreply@github.com>2021-07-26 12:49:53 -0400
commitb7186c6e8ddacc328ae2c155162b36291a3c2b79 (patch)
tree2ade38cead9dc9a8cff31405961a55378aae78e5 /synapse/state
parentUpdate the MSC3083 support to verify if joins are from an authorized server. ... (diff)
downloadsynapse-b7186c6e8ddacc328ae2c155162b36291a3c2b79.tar.xz
Add type hints to state handler. (#10482)
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py26
1 files changed, 15 insertions, 11 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2e15471435..463ce58dae 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@ import heapq
 import logging
 from collections import defaultdict, namedtuple
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -52,6 +53,10 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure, measure_func
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
 metrics_logger = logging.getLogger("synapse.state.metrics")
 
@@ -74,7 +79,7 @@ _NEXT_STATE_ID = 1
 POWER_KEY = (EventTypes.PowerLevels, "")
 
 
-def _gen_state_id():
+def _gen_state_id() -> str:
     global _NEXT_STATE_ID
     s = "X%d" % (_NEXT_STATE_ID,)
     _NEXT_STATE_ID += 1
@@ -109,7 +114,7 @@ class _StateCacheEntry:
         # `state_id` is either a state_group (and so an int) or a string. This
         # ensures we don't accidentally persist a state_id as a stateg_group
         if state_group:
-            self.state_id = state_group
+            self.state_id: Union[str, int] = state_group
         else:
             self.state_id = _gen_state_id()
 
@@ -122,7 +127,7 @@ class StateHandler:
     where necessary
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state_store = hs.get_storage().state
@@ -507,7 +512,7 @@ class StateResolutionHandler:
     be storage-independent.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
 
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
@@ -657,13 +662,15 @@ class StateResolutionHandler:
         finally:
             self._record_state_res_metrics(room_id, m.get_resource_usage())
 
-    def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage):
+    def _record_state_res_metrics(
+        self, room_id: str, rusage: ContextResourceUsage
+    ) -> None:
         room_metrics = self._state_res_metrics[room_id]
         room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime
         room_metrics.db_time += rusage.db_txn_duration_sec
         room_metrics.db_events += rusage.evt_db_fetch_count
 
-    def _report_metrics(self):
+    def _report_metrics(self) -> None:
         if not self._state_res_metrics:
             # no state res has happened since the last iteration: don't bother logging.
             return
@@ -773,16 +780,13 @@ def _make_state_cache_entry(
     )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class StateResolutionStore:
     """Interface that allows state resolution algorithms to access the database
     in well defined way.
-
-    Args:
-        store (DataStore)
     """
 
-    store = attr.ib()
+    store: "DataStore"
 
     def get_events(
         self, event_ids: Iterable[str], allow_rejected: bool = False