diff options
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 117 |
1 files changed, 67 insertions, 50 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 781d9f06da..9f0a36652c 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -31,7 +31,6 @@ from typing import ( Sequence, Set, Tuple, - Union, ) import attr @@ -47,6 +46,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo +from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -54,6 +54,7 @@ from synapse.util.metrics import Measure, measure_func if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -83,17 +84,20 @@ def _gen_state_id() -> str: class _StateCacheEntry: - __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] + __slots__ = ["state", "state_group", "prev_group", "delta_ids"] def __init__( self, - state: StateMap[str], + state: Optional[StateMap[str]], state_group: Optional[int], prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ): + if state is None and state_group is None: + raise Exception("Either state or state group must be not None") + # A map from (type, state_key) to event_id. - self.state = frozendict(state) + self.state = frozendict(state) if state is not None else None # the ID of a state group if one and only one is involved. # otherwise, None otherwise? @@ -102,20 +106,30 @@ class _StateCacheEntry: self.prev_group = prev_group self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None - # The `state_id` is a unique ID we generate that can be used as ID for - # this collection of state. Usually this would be the same as the - # state group, but on worker instances we can't generate a new state - # group each time we resolve state, so we generate a separate one that - # isn't persisted and is used solely for caches. - # `state_id` is either a state_group (and so an int) or a string. This - # ensures we don't accidentally persist a state_id as a stateg_group - if state_group: - self.state_id: Union[str, int] = state_group - else: - self.state_id = _gen_state_id() + async def get_state( + self, + state_storage: "StateStorageController", + state_filter: Optional["StateFilter"] = None, + ) -> StateMap[str]: + """Get the state map for this entry, either from the in-memory state or + looking up the state group in the DB. + """ + + if self.state is not None: + return self.state + + assert self.state_group is not None + + return await state_storage.get_state_ids_for_group( + self.state_group, state_filter + ) def __len__(self) -> int: - return len(self.state) + # The len should is used to estimate how large this cache entry is, for + # cache eviction purposes. This is why if `self.state` is None it's fine + # to return 1. + + return len(self.state) if self.state else 1 class StateHandler: @@ -153,7 +167,7 @@ class StateHandler: """ logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return ret.state + return await ret.get_state(self._state_storage_controller, StateFilter.all()) async def get_current_users_in_room( self, room_id: str, latest_event_ids: List[str] @@ -177,7 +191,8 @@ class StateHandler: logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return await self.store.get_joined_users_from_state(room_id, entry) + state = await entry.get_state(self._state_storage_controller, StateFilter.all()) + return await self.store.get_joined_users_from_state(room_id, state, entry) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] @@ -192,7 +207,8 @@ class StateHandler: The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - return await self.store.get_joined_hosts(room_id, entry) + state = await entry.get_state(self._state_storage_controller, StateFilter.all()) + return await self.store.get_joined_hosts(room_id, state, entry) async def compute_event_context( self, @@ -227,10 +243,19 @@ class StateHandler: # if state_ids_before_event: # if we're given the state before the event, then we use that - state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None - entry = None + + # .. though we need to get a state group for it. + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=None, + delta_ids=None, + current_state_ids=state_ids_before_event, + ) + ) else: # otherwise, we'll need to resolve the state across the prev_events. @@ -264,36 +289,27 @@ class StateHandler: await_full_state=False, ) - state_ids_before_event = entry.state - state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids - # - # make sure that we have a state group at that point. If it's not a state event, - # that will be the state group for the new event. If it *is* a state event, - # it might get rejected (in which case we'll need to persist it with the - # previous state group) - # - - if not state_group_before_event: - state_group_before_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + # We make sure that we have a state group assigned to the state. + if entry.state_group is None: + state_ids_before_event = await entry.get_state( + self._state_storage_controller, StateFilter.all() + ) + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) - ) - - # Assign the new state group to the cached state entry. - # - # Note that this can race in that we could generate multiple state - # groups for the same state entry, but that is just inefficient - # rather than dangerous. - if entry and entry.state_group is None: entry.state_group = state_group_before_event + else: + state_group_before_event = entry.state_group + state_ids_before_event = None # # now if it's not a state event, we're done @@ -313,6 +329,10 @@ class StateHandler: # # otherwise, we'll need to create a new state group for after the event # + if state_ids_before_event is None: + state_ids_before_event = await entry.get_state( + self._state_storage_controller, StateFilter.all() + ) key = (event.type, event.state_key) if key in state_ids_before_event: @@ -372,9 +392,6 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self._state_storage_controller.get_state_for_groups( - state_group_ids_set - ) ( prev_group, delta_ids, @@ -382,7 +399,7 @@ class StateHandler: state_group_id ) return _StateCacheEntry( - state=state[state_group_id], + state=None, state_group=state_group_id, prev_group=prev_group, delta_ids=delta_ids, |