diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 216 |
1 files changed, 114 insertions, 102 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dc568476f4..534883361f 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,13 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar +from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -34,16 +33,16 @@ class StateFilter(object): """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. - include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing @@ -52,36 +51,35 @@ class StateFilter(object): self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -91,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -130,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -167,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -179,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). """ where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -221,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -260,33 +257,33 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) @@ -295,7 +292,7 @@ class StateFilter(object): for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -307,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -340,6 +337,9 @@ class StateGroupStorage(object): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. + Returns: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) @@ -347,55 +347,59 @@ class StateGroupStorage(object): return self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id for group_ids in group_to_ids.values() @@ -415,7 +419,7 @@ class StateGroupStorage(object): def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -423,31 +427,34 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) @@ -463,24 +470,24 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) @@ -491,67 +498,72 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. + Returns: - Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, ): """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: |