diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index da25f20ae5..3787d35b24 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
from typing import (
TYPE_CHECKING,
Any,
@@ -24,14 +24,12 @@ from typing import (
DefaultDict,
Dict,
FrozenSet,
- Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
- Union,
)
import attr
@@ -43,9 +41,10 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import ContextResourceUsage
+from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -53,6 +52,7 @@ from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -82,17 +82,26 @@ def _gen_state_id() -> str:
class _StateCacheEntry:
- __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+ __slots__ = ["_state", "state_group", "prev_group", "delta_ids"]
def __init__(
self,
- state: StateMap[str],
+ state: Optional[StateMap[str]],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
+ if state is None and state_group is None and prev_group is None:
+ raise Exception("One of state, state_group or prev_group must be not None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("If prev_group is set so must delta_ids")
+
# A map from (type, state_key) to event_id.
- self.state = frozendict(state)
+ #
+ # This can be None if we have a `state_group` (as then we can fetch the
+ # state from the DB.)
+ self._state = frozendict(state) if state is not None else None
# the ID of a state group if one and only one is involved.
# otherwise, None otherwise?
@@ -101,20 +110,60 @@ class _StateCacheEntry:
self.prev_group = prev_group
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
- # The `state_id` is a unique ID we generate that can be used as ID for
- # this collection of state. Usually this would be the same as the
- # state group, but on worker instances we can't generate a new state
- # group each time we resolve state, so we generate a separate one that
- # isn't persisted and is used solely for caches.
- # `state_id` is either a state_group (and so an int) or a string. This
- # ensures we don't accidentally persist a state_id as a stateg_group
- if state_group:
- self.state_id: Union[str, int] = state_group
- else:
- self.state_id = _gen_state_id()
+ async def get_state(
+ self,
+ state_storage: "StateStorageController",
+ state_filter: Optional["StateFilter"] = None,
+ ) -> StateMap[str]:
+ """Get the state map for this entry, either from the in-memory state or
+ looking up the state group in the DB.
+ """
+
+ if self._state is not None:
+ return self._state
+
+ if self.state_group is not None:
+ return await state_storage.get_state_ids_for_group(
+ self.state_group, state_filter
+ )
+
+ assert self.prev_group is not None and self.delta_ids is not None
+
+ prev_state = await state_storage.get_state_ids_for_group(
+ self.prev_group, state_filter
+ )
+
+ # ChainMap expects MutableMapping, but since we're using it immutably
+ # it's safe to give it immutable maps.
+ return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type]
+
+ def set_state_group(self, state_group: int) -> None:
+ """Update the state group assigned to this state (e.g. after we've
+ persisted it).
+
+ Note: this will cause the cache entry to drop any stored state.
+ """
+
+ self.state_group = state_group
+
+ # We clear out the state as we no longer need to explicitly keep it in
+ # the `state_cache` (as the store state group cache will do that).
+ self._state = None
def __len__(self) -> int:
- return len(self.state)
+ # The len should be used to estimate how large this cache entry is, for
+ # cache eviction purposes. This is why it's fine to return 1 if we're
+ # not storing any state.
+
+ length = 0
+
+ if self._state:
+ length += len(self._state)
+
+ if self.delta_ids:
+ length += len(self.delta_ids)
+
+ return length or 1 # Make sure its not 0.
class StateHandler:
@@ -129,30 +178,42 @@ class StateHandler:
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage_controllers = hs.get_storage_controllers()
+ self._events_shard_config = hs.config.worker.events_shard_config
+ self._instance_name = hs.get_instance_name()
- async def get_current_state_ids(
+ self._update_current_state_client = (
+ ReplicationUpdateCurrentStateRestServlet.make_client(hs)
+ )
+
+ async def compute_state_after_events(
self,
room_id: str,
- latest_event_ids: Collection[str],
+ event_ids: Collection[str],
+ state_filter: Optional[StateFilter] = None,
) -> StateMap[str]:
- """Get the current state, or the state at a set of events, for a room
+ """Fetch the state after each of the given event IDs. Resolve them and return.
+
+ This is typically used where `event_ids` is a collection of forward extremities
+ in a room, intended to become the `prev_events` of a new event E. If so, the
+ return value of this function represents the state before E.
Args:
- room_id:
- latest_event_ids: The forward extremities to resolve.
+ room_id: the room_id containing the given events.
+ event_ids: the events whose state should be fetched and resolved.
Returns:
- the state dict, mapping from (event_type, state_key) -> event_id
+ the state dict (a mapping from (event_type, state_key) -> event_id) which
+ holds the resolution of the states after the given event IDs.
"""
- logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- return ret.state
+ logger.debug("calling resolve_state_groups from compute_state_after_events")
+ ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+ return await ret.get_state(self._state_storage_controller, state_filter)
- async def get_current_users_in_room(
+ async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
"""
- Get the users who are currently in a room.
+ Get the IDs of the users who are currently in a room.
Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -163,14 +224,15 @@ class StateHandler:
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Dictionary of user IDs to their profileinfo.
+ Set of user IDs in the room.
"""
assert latest_event_ids is not None
- logger.debug("calling resolve_state_groups from get_current_users_in_room")
+ logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
- return await self.store.get_joined_users_from_state(room_id, entry)
+ state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+ return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -185,13 +247,14 @@ class StateHandler:
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
- return await self.store.get_joined_hosts(room_id, entry)
+ state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+ return await self.store.get_joined_hosts(room_id, state, entry)
async def compute_event_context(
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
- partial_state: bool = False,
+ partial_state: Optional[bool] = None,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -206,10 +269,18 @@ class StateHandler:
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
- partial_state: True if `state_ids_before_event` is partial and omits
- non-critical membership events
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
Returns:
The event context.
+
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
"""
assert not event.internal_metadata.is_outlier()
@@ -220,17 +291,28 @@ class StateHandler:
#
if state_ids_before_event:
# if we're given the state before the event, then we use that
- state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
- entry = None
+ # .. though we need to get a state group for it.
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=None,
+ delta_ids=None,
+ current_state_ids=state_ids_before_event,
+ )
+ )
+
+ # the partial_state flag must be provided
+ assert partial_state is not None
else:
# otherwise, we'll need to resolve the state across the prev_events.
# partial_state should not be set explicitly in this case:
# we work it out dynamically
- assert not partial_state
+ assert partial_state is None
# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
@@ -240,49 +322,49 @@ class StateHandler:
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
- if any(incomplete_prev_events.values()):
+ partial_state = any(incomplete_prev_events.values())
+ if partial_state:
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
- partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context")
+ # we've already taken into account partial state, so no need to wait for
+ # complete state here.
entry = await self.resolve_state_groups_for_events(
- event.room_id, event.prev_event_ids()
+ event.room_id,
+ event.prev_event_ids(),
+ await_full_state=False,
)
- state_ids_before_event = entry.state
- state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
+ state_ids_before_event = None
+
+ # We make sure that we have a state group assigned to the state.
+ if entry.state_group is None:
+ # store_state_group requires us to have either a previous state group
+ # (with deltas) or the complete state map. So, if we don't have a
+ # previous state group, load the complete state map now.
+ if state_group_before_event_prev_group is None:
+ state_ids_before_event = await entry.get_state(
+ self._state_storage_controller, StateFilter.all()
+ )
- #
- # make sure that we have a state group at that point. If it's not a state event,
- # that will be the state group for the new event. If it *is* a state event,
- # it might get rejected (in which case we'll need to persist it with the
- # previous state group)
- #
-
- if not state_group_before_event:
- state_group_before_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- current_state_ids=state_ids_before_event,
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
)
- )
-
- # Assign the new state group to the cached state entry.
- #
- # Note that this can race in that we could generate multiple state
- # groups for the same state entry, but that is just inefficient
- # rather than dangerous.
- if entry and entry.state_group is None:
- entry.state_group = state_group_before_event
+ entry.set_state_group(state_group_before_event)
+ else:
+ state_group_before_event = entry.state_group
#
# now if it's not a state event, we're done
@@ -304,13 +386,18 @@ class StateHandler:
#
key = (event.type, event.state_key)
- if key in state_ids_before_event:
- replaces = state_ids_before_event[key]
- if replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
- state_ids_after_event = dict(state_ids_before_event)
- state_ids_after_event[key] = event.event_id
+ if state_ids_before_event is not None:
+ replaces = state_ids_before_event.get(key)
+ else:
+ replaces_state_map = await entry.get_state(
+ self._state_storage_controller, StateFilter.from_types([key])
+ )
+ replaces = replaces_state_map.get(key)
+
+ if replaces and replaces != event.event_id:
+ event.unsigned["replaces_state"] = replaces
+
delta_ids = {key: event.event_id}
state_group_after_event = (
@@ -319,7 +406,7 @@ class StateHandler:
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
- current_state_ids=state_ids_after_event,
+ current_state_ids=None,
)
)
@@ -335,7 +422,7 @@ class StateHandler:
@measure_func()
async def resolve_state_groups_for_events(
- self, room_id: str, event_ids: Collection[str]
+ self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -343,14 +430,20 @@ class StateHandler:
Args:
room_id
event_ids
+ await_full_state: if true, will block if we do not yet have complete
+ state at these events.
Returns:
The resolved state
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (i.e. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = await self._state_storage_controller.get_state_group_for_events(
- event_ids
+ event_ids, await_full_state=await_full_state
)
state_group_ids = state_groups.values()
@@ -359,9 +452,6 @@ class StateHandler:
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
- state = await self._state_storage_controller.get_state_for_groups(
- state_group_ids_set
- )
(
prev_group,
delta_ids,
@@ -369,7 +459,7 @@ class StateHandler:
state_group_id
)
return _StateCacheEntry(
- state=state[state_group_id],
+ state=None,
state_group=state_group_id,
prev_group=prev_group,
delta_ids=delta_ids,
@@ -392,30 +482,23 @@ class StateHandler:
)
return result
- async def resolve_events(
- self,
- room_version: str,
- state_sets: Collection[Iterable[EventBase]],
- event: EventBase,
- ) -> StateMap[EventBase]:
- logger.info(
- "Resolving state for %s with %d groups", event.room_id, len(state_sets)
- )
- state_set_ids = [
- {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
- ]
-
- state_map = {ev.event_id: ev for st in state_sets for ev in st}
+ async def update_current_state(self, room_id: str) -> None:
+ """Recalculates the current state for a room, and persists it.
- new_state = await self._state_resolution_handler.resolve_events_with_store(
- event.room_id,
- room_version,
- state_set_ids,
- event_map=state_map,
- state_res_store=StateResolutionStore(self.store),
- )
+ Raises:
+ SynapseError(502): if all attempts to connect to the event persister worker
+ fail
+ """
+ writer_instance = self._events_shard_config.get_instance(room_id)
+ if writer_instance != self._instance_name:
+ await self._update_current_state_client(
+ instance_name=writer_instance,
+ room_id=room_id,
+ )
+ return
- return {key: state_map[ev_id] for key, ev_id in new_state.items()}
+ assert self._storage_controllers.persistence is not None
+ await self._storage_controllers.persistence.update_current_state(room_id)
@attr.s(slots=True, auto_attribs=True)
@@ -444,6 +527,15 @@ _biggest_room_by_db_counter = Counter(
"expensive room for state resolution",
)
+_cpu_times = Histogram(
+ "synapse_state_res_cpu_for_all_rooms_seconds",
+ "CPU time (utime+stime) spent computing a single state resolution",
+)
+_db_times = Histogram(
+ "synapse_state_res_db_for_all_rooms_seconds",
+ "Database time spent computing a single state resolution",
+)
+
class StateResolutionHandler:
"""Responsible for doing state conflict resolution.
@@ -609,6 +701,9 @@ class StateResolutionHandler:
room_metrics.db_time += rusage.db_txn_duration_sec
room_metrics.db_events += rusage.evt_db_fetch_count
+ _cpu_times.observe(rusage.ru_utime + rusage.ru_stime)
+ _db_times.observe(rusage.db_txn_duration_sec)
+
def _report_metrics(self) -> None:
if not self._state_res_metrics:
# no state res has happened since the last iteration: don't bother logging.
@@ -698,7 +793,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
- return _StateCacheEntry(state=new_state, state_group=sg)
+ return _StateCacheEntry(state=None, state_group=sg)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
@@ -709,14 +804,25 @@ def _make_state_cache_entry(
delta_ids: Optional[StateMap[str]] = None
for old_group, old_state in state_groups_ids.items():
+ if old_state.keys() - new_state.keys():
+ # Currently we don't support deltas that remove keys from the state
+ # map, so we have to ignore this group as a candidate to base the
+ # new group on.
+ continue
+
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
- return _StateCacheEntry(
- state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
- )
+ if prev_group is not None:
+ # If we have a prev group and deltas then we can drop the new state from
+ # the cache (to reduce memory usage).
+ return _StateCacheEntry(
+ state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+ )
+ else:
+ return _StateCacheEntry(state=new_state, state_group=None)
@attr.s(slots=True, auto_attribs=True)
|