summary refs log tree commit diff
path: root/synapse/storage/state.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-01-16 13:31:22 +0000
committerGitHub <noreply@github.com>2020-01-16 13:31:22 +0000
commitd386f2f339c839ff6ec8d656492dd635dc26f811 (patch)
treef127b8130fb9d778d863300e97c3ac55945e2e61 /synapse/storage/state.py
parentAdd tips for the changelog to the pull request template (#6663) (diff)
downloadsynapse-d386f2f339c839ff6ec8d656492dd635dc26f811.tar.xz
Add StateMap type alias (#6715)
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r--synapse/storage/state.py35
1 files changed, 21 insertions, 14 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Iterable, List, TypeVar
 
 from six import iteritems, itervalues
 
@@ -22,9 +23,13 @@ import attr
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
+# Used for generic functions below
+T = TypeVar("T")
+
 
 @attr.s(slots=True)
 class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict):
+    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
         """Returns the state filtered with by this StateFilter
 
         Args:
-            state (dict[tuple[str, str], Any]): The state map to filter
+            state: The state map to filter
 
         Returns:
-            dict[tuple[str, str], Any]: The filtered state map
+            The filtered state map
         """
         if self.is_full():
             return dict(state_dict)
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    def get_state_group_delta(self, state_group):
+    def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                 (prev_group, delta_ids)
         """
 
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
             event_ids (iterable[str]): ids of the events
 
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
+            Deferred[dict[int, StateMap[str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
             for group, event_id_map in iteritems(group_to_ids)
         }
 
-    def _get_state_groups_from_groups(self, groups, state_filter):
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
             state_filter (StateFilter): The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)