diff options
Diffstat (limited to 'synapse/storage/state.py')
-rw-r--r-- | synapse/storage/state.py | 82 |
1 files changed, 71 insertions, 11 deletions
diff --git a/synapse/storage/state.py b/synapse/storage/state.py index d1d5859214..ab630953ac 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +16,7 @@ import logging from typing import ( TYPE_CHECKING, Awaitable, + Callable, Collection, Dict, Iterable, @@ -62,7 +64,7 @@ class StateFilter: types: "frozendict[str, Optional[FrozenSet[str]]]" include_others: bool = False - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: @@ -138,7 +140,9 @@ class StateFilter: ) @staticmethod - def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": """ Returns a (frozen) StateFilter with the same contents as the parameters specified here, which can be made of mutable types. @@ -530,6 +534,44 @@ class StateFilter: new_all, new_excludes, new_wildcards, new_concrete_keys ) + def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: + """Check if we need to wait for full state to complete to calculate this state + + If we have a state filter which is completely satisfied even with partial + state, then we don't need to await_full_state before we can return it. + + Args: + is_mine_id: a callable which confirms if a given state_key matches a mxid + of a local user + """ + + # TODO(faster_joins): it's not entirely clear that this is safe. In particular, + # there may be circumstances in which we return a piece of state that, once we + # resync the state, we discover is invalid. For example: if it turns out that + # the sender of a piece of state wasn't actually in the room, then clearly that + # state shouldn't have been returned. + # We should at least add some tests around this to see what happens. + + # if we haven't requested membership events, then it depends on the value of + # 'include_others' + if EventTypes.Member not in self.types: + return self.include_others + + # if we're looking for *all* membership events, then we have to wait + member_state_keys = self.types[EventTypes.Member] + if member_state_keys is None: + return True + + # otherwise, consider whose membership we are looking for. If it's entirely + # local users, then we don't need to wait. + for state_key in member_state_keys: + if not is_mine_id(state_key): + # remote user + return True + + # local users only + return False + _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) _ALL_NON_MEMBER_STATE_FILTER = StateFilter( @@ -542,6 +584,7 @@ class StateGroupStorage: """High level interface to fetching state for event.""" def __init__(self, hs: "HomeServer", stores: "Databases"): + self._is_mine_id = hs.is_mine_id self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) @@ -584,23 +627,26 @@ class StateGroupStorage: if not event_ids: return {} - event_to_groups = await self._get_state_group_for_events(event_ids) + event_to_groups = await self.get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: + async def get_state_ids_for_group( + self, state_group: int, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: state_group: A state group for which we want to get the state IDs. + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules Returns: Resolves to a map of (type, state_key) -> event_id """ - group_to_state = await self._get_state_for_groups((state_group,)) + group_to_state = await self.get_state_for_groups((state_group,), state_filter) return group_to_state[state_group] @@ -673,7 +719,13 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self._get_state_group_for_events(event_ids) + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -697,7 +749,9 @@ class StateGroupStorage: return {event: event_to_state[event] for event in event_ids} async def get_state_ids_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + self, + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids @@ -714,7 +768,13 @@ class StateGroupStorage: RuntimeError if we don't have a state group for one or more of the events (ie they are outliers or unknown) """ - event_to_groups = await self._get_state_group_for_events(event_ids) + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) groups = set(event_to_groups.values()) group_to_state = await self.stores.state._get_state_for_groups( @@ -772,7 +832,7 @@ class StateGroupStorage: ) return state_map[event_id] - def _get_state_for_groups( + def get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Awaitable[Dict[int, MutableStateMap[str]]]: """Gets the state at each of a list of state groups, optionally @@ -790,7 +850,7 @@ class StateGroupStorage: groups, state_filter or StateFilter.all() ) - async def _get_state_group_for_events( + async def get_state_group_for_events( self, event_ids: Collection[str], await_full_state: bool = True, @@ -800,7 +860,7 @@ class StateGroupStorage: Args: event_ids: events to get state groups for await_full_state: if true, will block if we do not yet have complete - state at this event. + state at these events. """ if await_full_state: await self._partial_state_events_tracker.await_full_state(event_ids) |