summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py324
1 files changed, 215 insertions, 109 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py

index da25f20ae5..3787d35b24 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@ # limitations under the License. import heapq import logging -from collections import defaultdict +from collections import ChainMap, defaultdict from typing import ( TYPE_CHECKING, Any, @@ -24,14 +24,12 @@ from typing import ( DefaultDict, Dict, FrozenSet, - Iterable, List, Mapping, Optional, Sequence, Set, Tuple, - Union, ) import attr @@ -43,9 +41,10 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import ContextResourceUsage +from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.roommember import ProfileInfo +from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -53,6 +52,7 @@ from synapse.util.metrics import Measure, measure_func if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -82,17 +82,26 @@ def _gen_state_id() -> str: class _StateCacheEntry: - __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] + __slots__ = ["_state", "state_group", "prev_group", "delta_ids"] def __init__( self, - state: StateMap[str], + state: Optional[StateMap[str]], state_group: Optional[int], prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ): + if state is None and state_group is None and prev_group is None: + raise Exception("One of state, state_group or prev_group must be not None") + + if prev_group is not None and delta_ids is None: + raise Exception("If prev_group is set so must delta_ids") + # A map from (type, state_key) to event_id. - self.state = frozendict(state) + # + # This can be None if we have a `state_group` (as then we can fetch the + # state from the DB.) + self._state = frozendict(state) if state is not None else None # the ID of a state group if one and only one is involved. # otherwise, None otherwise? @@ -101,20 +110,60 @@ class _StateCacheEntry: self.prev_group = prev_group self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None - # The `state_id` is a unique ID we generate that can be used as ID for - # this collection of state. Usually this would be the same as the - # state group, but on worker instances we can't generate a new state - # group each time we resolve state, so we generate a separate one that - # isn't persisted and is used solely for caches. - # `state_id` is either a state_group (and so an int) or a string. This - # ensures we don't accidentally persist a state_id as a stateg_group - if state_group: - self.state_id: Union[str, int] = state_group - else: - self.state_id = _gen_state_id() + async def get_state( + self, + state_storage: "StateStorageController", + state_filter: Optional["StateFilter"] = None, + ) -> StateMap[str]: + """Get the state map for this entry, either from the in-memory state or + looking up the state group in the DB. + """ + + if self._state is not None: + return self._state + + if self.state_group is not None: + return await state_storage.get_state_ids_for_group( + self.state_group, state_filter + ) + + assert self.prev_group is not None and self.delta_ids is not None + + prev_state = await state_storage.get_state_ids_for_group( + self.prev_group, state_filter + ) + + # ChainMap expects MutableMapping, but since we're using it immutably + # its safe to give it immutable maps. + return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type] + + def set_state_group(self, state_group: int) -> None: + """Update the state group assigned to this state (e.g. after we've + persisted it). + + Note: this will cause the cache entry to drop any stored state. + """ + + self.state_group = state_group + + # We clear out the state as we know longer need to explicitly keep it in + # the `state_cache` (as the store state group cache will do that). + self._state = None def __len__(self) -> int: - return len(self.state) + # The len should be used to estimate how large this cache entry is, for + # cache eviction purposes. This is why it's fine to return 1 if we're + # not storing any state. + + length = 0 + + if self._state: + length += len(self._state) + + if self.delta_ids: + length += len(self.delta_ids) + + return length or 1 # Make sure its not 0. class StateHandler: @@ -129,30 +178,42 @@ class StateHandler: self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() self._storage_controllers = hs.get_storage_controllers() + self._events_shard_config = hs.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() - async def get_current_state_ids( + self._update_current_state_client = ( + ReplicationUpdateCurrentStateRestServlet.make_client(hs) + ) + + async def compute_state_after_events( self, room_id: str, - latest_event_ids: Collection[str], + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, ) -> StateMap[str]: - """Get the current state, or the state at a set of events, for a room + """Fetch the state after each of the given event IDs. Resolve them and return. + + This is typically used where `event_ids` is a collection of forward extremities + in a room, intended to become the `prev_events` of a new event E. If so, the + return value of this function represents the state before E. Args: - room_id: - latest_event_ids: The forward extremities to resolve. + room_id: the room_id containing the given events. + event_ids: the events whose state should be fetched and resolved. Returns: - the state dict, mapping from (event_type, state_key) -> event_id + the state dict (a mapping from (event_type, state_key) -> event_id) which + holds the resolution of the states after the given event IDs. """ - logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return ret.state + logger.debug("calling resolve_state_groups from compute_state_after_events") + ret = await self.resolve_state_groups_for_events(room_id, event_ids) + return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_users_in_room( + async def get_current_user_ids_in_room( self, room_id: str, latest_event_ids: List[str] - ) -> Dict[str, ProfileInfo]: + ) -> Set[str]: """ - Get the users who are currently in a room. + Get the users IDs who are currently in a room. Note: This is much slower than using the equivalent method `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, @@ -163,14 +224,15 @@ class StateHandler: room_id: The ID of the room. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. Returns: - Dictionary of user IDs to their profileinfo. + Set of user IDs in the room. """ assert latest_event_ids is not None - logger.debug("calling resolve_state_groups from get_current_users_in_room") + logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - return await self.store.get_joined_users_from_state(room_id, entry) + state = await entry.get_state(self._state_storage_controller, StateFilter.all()) + return await self.store.get_joined_user_ids_from_state(room_id, state) async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] @@ -185,13 +247,14 @@ class StateHandler: The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - return await self.store.get_joined_hosts(room_id, entry) + state = await entry.get_state(self._state_storage_controller, StateFilter.all()) + return await self.store.get_joined_hosts(room_id, state, entry) async def compute_event_context( self, event: EventBase, state_ids_before_event: Optional[StateMap[str]] = None, - partial_state: bool = False, + partial_state: Optional[bool] = None, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -206,10 +269,18 @@ class StateHandler: it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events, e.g. when backfilling. - partial_state: True if `state_ids_before_event` is partial and omits - non-critical membership events + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. Returns: The event context. + + Raises: + RuntimeError if `state_ids_before_event` is not provided and one or more + prev events are missing or outliers. """ assert not event.internal_metadata.is_outlier() @@ -220,17 +291,28 @@ class StateHandler: # if state_ids_before_event: # if we're given the state before the event, then we use that - state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None - entry = None + # .. though we need to get a state group for it. + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=None, + delta_ids=None, + current_state_ids=state_ids_before_event, + ) + ) + + # the partial_state flag must be provided + assert partial_state is not None else: # otherwise, we'll need to resolve the state across the prev_events. # partial_state should not be set explicitly in this case: # we work it out dynamically - assert not partial_state + assert partial_state is None # if any of the prev-events have partial state, so do we. # (This is slightly racy - the prev-events might get fixed up before we use @@ -240,49 +322,49 @@ class StateHandler: incomplete_prev_events = await self.store.get_partial_state_events( prev_event_ids ) - if any(incomplete_prev_events.values()): + partial_state = any(incomplete_prev_events.values()) + if partial_state: logger.debug( "New/incoming event %s refers to prev_events %s with partial state", event.event_id, [k for (k, v) in incomplete_prev_events.items() if v], ) - partial_state = True logger.debug("calling resolve_state_groups from compute_event_context") + # we've already taken into account partial state, so no need to wait for + # complete state here. entry = await self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() + event.room_id, + event.prev_event_ids(), + await_full_state=False, ) - state_ids_before_event = entry.state - state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids + state_ids_before_event = None + + # We make sure that we have a state group assigned to the state. + if entry.state_group is None: + # store_state_group requires us to have either a previous state group + # (with deltas) or the complete state map. So, if we don't have a + # previous state group, load the complete state map now. + if state_group_before_event_prev_group is None: + state_ids_before_event = await entry.get_state( + self._state_storage_controller, StateFilter.all() + ) - # - # make sure that we have a state group at that point. If it's not a state event, - # that will be the state group for the new event. If it *is* a state event, - # it might get rejected (in which case we'll need to persist it with the - # previous state group) - # - - if not state_group_before_event: - state_group_before_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) - ) - - # Assign the new state group to the cached state entry. - # - # Note that this can race in that we could generate multiple state - # groups for the same state entry, but that is just inefficient - # rather than dangerous. - if entry and entry.state_group is None: - entry.state_group = state_group_before_event + entry.set_state_group(state_group_before_event) + else: + state_group_before_event = entry.state_group # # now if it's not a state event, we're done @@ -304,13 +386,18 @@ class StateHandler: # key = (event.type, event.state_key) - if key in state_ids_before_event: - replaces = state_ids_before_event[key] - if replaces != event.event_id: - event.unsigned["replaces_state"] = replaces - state_ids_after_event = dict(state_ids_before_event) - state_ids_after_event[key] = event.event_id + if state_ids_before_event is not None: + replaces = state_ids_before_event.get(key) + else: + replaces_state_map = await entry.get_state( + self._state_storage_controller, StateFilter.from_types([key]) + ) + replaces = replaces_state_map.get(key) + + if replaces and replaces != event.event_id: + event.unsigned["replaces_state"] = replaces + delta_ids = {key: event.event_id} state_group_after_event = ( @@ -319,7 +406,7 @@ class StateHandler: event.room_id, prev_group=state_group_before_event, delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + current_state_ids=None, ) ) @@ -335,7 +422,7 @@ class StateHandler: @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Collection[str] + self, room_id: str, event_ids: Collection[str], await_full_state: bool = True ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -343,14 +430,20 @@ class StateHandler: Args: room_id event_ids + await_full_state: if true, will block if we do not yet have complete + state at these events. Returns: The resolved state + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie. they are outliers or unknown) """ logger.debug("resolve_state_groups event_ids %s", event_ids) state_groups = await self._state_storage_controller.get_state_group_for_events( - event_ids + event_ids, await_full_state=await_full_state ) state_group_ids = state_groups.values() @@ -359,9 +452,6 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self._state_storage_controller.get_state_for_groups( - state_group_ids_set - ) ( prev_group, delta_ids, @@ -369,7 +459,7 @@ class StateHandler: state_group_id ) return _StateCacheEntry( - state=state[state_group_id], + state=None, state_group=state_group_id, prev_group=prev_group, delta_ids=delta_ids, @@ -392,30 +482,23 @@ class StateHandler: ) return result - async def resolve_events( - self, - room_version: str, - state_sets: Collection[Iterable[EventBase]], - event: EventBase, - ) -> StateMap[EventBase]: - logger.info( - "Resolving state for %s with %d groups", event.room_id, len(state_sets) - ) - state_set_ids = [ - {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets - ] - - state_map = {ev.event_id: ev for st in state_sets for ev in st} + async def update_current_state(self, room_id: str) -> None: + """Recalculates the current state for a room, and persists it. - new_state = await self._state_resolution_handler.resolve_events_with_store( - event.room_id, - room_version, - state_set_ids, - event_map=state_map, - state_res_store=StateResolutionStore(self.store), - ) + Raises: + SynapseError(502): if all attempts to connect to the event persister worker + fail + """ + writer_instance = self._events_shard_config.get_instance(room_id) + if writer_instance != self._instance_name: + await self._update_current_state_client( + instance_name=writer_instance, + room_id=room_id, + ) + return - return {key: state_map[ev_id] for key, ev_id in new_state.items()} + assert self._storage_controllers.persistence is not None + await self._storage_controllers.persistence.update_current_state(room_id) @attr.s(slots=True, auto_attribs=True) @@ -444,6 +527,15 @@ _biggest_room_by_db_counter = Counter( "expensive room for state resolution", ) +_cpu_times = Histogram( + "synapse_state_res_cpu_for_all_rooms_seconds", + "CPU time (utime+stime) spent computing a single state resolution", +) +_db_times = Histogram( + "synapse_state_res_db_for_all_rooms_seconds", + "Database time spent computing a single state resolution", +) + class StateResolutionHandler: """Responsible for doing state conflict resolution. @@ -609,6 +701,9 @@ class StateResolutionHandler: room_metrics.db_time += rusage.db_txn_duration_sec room_metrics.db_events += rusage.evt_db_fetch_count + _cpu_times.observe(rusage.ru_utime + rusage.ru_stime) + _db_times.observe(rusage.db_txn_duration_sec) + def _report_metrics(self) -> None: if not self._state_res_metrics: # no state res has happened since the last iteration: don't bother logging. @@ -698,7 +793,7 @@ def _make_state_cache_entry( old_state_event_ids = set(state.values()) if new_state_event_ids == old_state_event_ids: # got an exact match. - return _StateCacheEntry(state=new_state, state_group=sg) + return _StateCacheEntry(state=None, state_group=sg) # TODO: We want to create a state group for this set of events, to # increase cache hits, but we need to make sure that it doesn't @@ -709,14 +804,25 @@ def _make_state_cache_entry( delta_ids: Optional[StateMap[str]] = None for old_group, old_state in state_groups_ids.items(): + if old_state.keys() - new_state.keys(): + # Currently we don't support deltas that remove keys from the state + # map, so we have to ignore this group as a candidate to base the + # new group on. + continue + n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} if not delta_ids or len(n_delta_ids) < len(delta_ids): prev_group = old_group delta_ids = n_delta_ids - return _StateCacheEntry( - state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids - ) + if prev_group is not None: + # If we have a prev group and deltas then we can drop the new state from + # the cache (to reduce memory usage). + return _StateCacheEntry( + state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids + ) + else: + return _StateCacheEntry(state=new_state, state_group=None) @attr.s(slots=True, auto_attribs=True)