diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d07440e3ed..33bebd1c48 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+ def get_filtered_current_state_ids(
+ self, room_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
- room_id (str)
- state_filter (StateFilter): The state filter used to fetch state
+ room_id
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
- event ID.
+ defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index d53695f238..c4ee9b7ccb 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
from six import iteritems
from six.moves import range
@@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
+from synapse.types import StateMap
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
+ """Returns the state groups for a given set of groups from the
+ database, filtering on types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
results = {}
@@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
- groups (iterable[int]): list of state groups for which we want
+ groups: list of state groups for which we want
to get the state.
- state_filter (StateFilter): The state filter used to fetch state
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+ def _get_state_for_groups_using_cache(
+ self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+ ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ groups: list of state groups for which we want to get the state.
+ cache: the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the
+ specific members state cache.
+ state_filter: The state filter used to fetch state from the
+ database.
Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
+ Tuple of dict of state_group_id to state map of entries in the
+ cache, and the state group ids either missing from the cache or
+ incomplete.
"""
results = {}
incomplete_groups = set()
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
@@ -22,9 +23,13 @@ import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
+# Used for generic functions below
+T = TypeVar("T")
+
@attr.s(slots=True)
class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types())
- def filter_state(self, state_dict):
+ def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
Args:
- state (dict[tuple[str, str], Any]): The state map to filter
+ state: The state map to filter
Returns:
- dict[tuple[str, str], Any]: The filtered state map
+ The filtered state map
"""
if self.is_full():
return dict(state_dict)
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores
- def get_state_group_delta(self, state_group):
+ def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+ Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
"""
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
event_ids (iterable[str]): ids of the events
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
+ Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
for group, event_id_map in iteritems(group_to_ids)
}
- def _get_state_groups_from_groups(self, groups, state_filter):
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
|