diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index e3faa52cd6..3047e1b1ad 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
from typing import (
TYPE_CHECKING,
Any,
@@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
@@ -92,8 +91,11 @@ class _StateCacheEntry:
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
- if state is None and state_group is None:
- raise Exception("Either state or state group must be not None")
+ if state is None and state_group is None and prev_group is None:
+ raise Exception("One of state, state_group or prev_group must be not None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("If prev_group is set so must delta_ids")
# A map from (type, state_key) to event_id.
#
@@ -120,18 +122,48 @@ class _StateCacheEntry:
if self._state is not None:
return self._state
- assert self.state_group is not None
+ if self.state_group is not None:
+ return await state_storage.get_state_ids_for_group(
+ self.state_group, state_filter
+ )
+
+ assert self.prev_group is not None and self.delta_ids is not None
- return await state_storage.get_state_ids_for_group(
- self.state_group, state_filter
+ prev_state = await state_storage.get_state_ids_for_group(
+ self.prev_group, state_filter
)
+ # ChainMap expects MutableMapping, but since we're using it immutably
+ # it's safe to give it immutable maps.
+ return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type]
+
+ def set_state_group(self, state_group: int) -> None:
+ """Update the state group assigned to this state (e.g. after we've
+ persisted it).
+
+ Note: this will cause the cache entry to drop any stored state.
+ """
+
+ self.state_group = state_group
+
+ # We clear out the state as we no longer need to explicitly keep it in
+ # the `state_cache` (as the store state group cache will do that).
+ self._state = None
+
def __len__(self) -> int:
- # The len should is used to estimate how large this cache entry is, for
- # cache eviction purposes. This is why if `self.state` is None it's fine
- # to return 1.
+ # The len should be used to estimate how large this cache entry is, for
+ # cache eviction purposes. This is why it's fine to return 1 if we're
+ # not storing any state.
+
+ length = 0
+
+ if self._state:
+ length += len(self._state)
- return len(self._state) if self._state else 1
+ if self.delta_ids:
+ length += len(self.delta_ids)
+
+ return length or 1 # Make sure its not 0.
class StateHandler:
@@ -177,11 +209,11 @@ class StateHandler:
ret = await self.resolve_state_groups_for_events(room_id, event_ids)
return await ret.get_state(self._state_storage_controller, state_filter)
- async def get_current_users_in_room(
+ async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
"""
- Get the users who are currently in a room.
+ Get the IDs of the users who are currently in a room.
Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -192,15 +224,15 @@ class StateHandler:
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Dictionary of user IDs to their profileinfo.
+ Set of user IDs in the room.
"""
assert latest_event_ids is not None
- logger.debug("calling resolve_state_groups from get_current_users_in_room")
+ logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
- return await self.store.get_joined_users_from_state(room_id, state, entry)
+ return await self.store.get_joined_user_ids_from_state(room_id, state, entry)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -222,7 +254,7 @@ class StateHandler:
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
- partial_state: bool = False,
+ partial_state: Optional[bool] = None,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -237,10 +269,18 @@ class StateHandler:
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
- partial_state: True if `state_ids_before_event` is partial and omits
- non-critical membership events
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
Returns:
The event context.
+
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
"""
assert not event.internal_metadata.is_outlier()
@@ -265,12 +305,14 @@ class StateHandler:
)
)
+ # the partial_state flag must be provided
+ assert partial_state is not None
else:
# otherwise, we'll need to resolve the state across the prev_events.
# partial_state should not be set explicitly in this case:
# we work it out dynamically
- assert not partial_state
+ assert partial_state is None
# if any of the prev-events have partial state, so do we.
# (This is slightly racy - the prev-events might get fixed up before we use
@@ -280,13 +322,13 @@ class StateHandler:
incomplete_prev_events = await self.store.get_partial_state_events(
prev_event_ids
)
- if any(incomplete_prev_events.values()):
+ partial_state = any(incomplete_prev_events.values())
+ if partial_state:
logger.debug(
"New/incoming event %s refers to prev_events %s with partial state",
event.event_id,
[k for (k, v) in incomplete_prev_events.items() if v],
)
- partial_state = True
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
@@ -320,7 +362,7 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
)
- entry.state_group = state_group_before_event
+ entry.set_state_group(state_group_before_event)
else:
state_group_before_event = entry.state_group
@@ -393,6 +435,10 @@ class StateHandler:
Returns:
The resolved state
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (i.e. they are outliers or unknown)
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -747,7 +793,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
- return _StateCacheEntry(state=new_state, state_group=sg)
+ return _StateCacheEntry(state=None, state_group=sg)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
@@ -769,9 +815,14 @@ def _make_state_cache_entry(
prev_group = old_group
delta_ids = n_delta_ids
- return _StateCacheEntry(
- state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
- )
+ if prev_group is not None:
+ # If we have a prev group and deltas then we can drop the new state from
+ # the cache (to reduce memory usage).
+ return _StateCacheEntry(
+ state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+ )
+ else:
+ return _StateCacheEntry(state=new_state, state_group=None)
@attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7db032203b..cf3045f82e 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -434,7 +434,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore,
- auth_diff: Set[str],
+ full_conflicted_set: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
@@ -445,7 +445,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: Event to add to the graph
event_map
state_res_store
- auth_diff: Set of event IDs that are in the auth difference.
+ full_conflicted_set: Set of event IDs that are in the full conflicted set.
"""
state = [event_id]
@@ -455,7 +455,7 @@ async def _add_event_and_auth_chain_to_graph(
event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
- if aid in auth_diff:
+ if aid in full_conflicted_set:
if aid not in graph:
state.append(aid)
@@ -468,7 +468,7 @@ async def _reverse_topological_power_sort(
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore,
- auth_diff: Set[str],
+ full_conflicted_set: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
@@ -479,7 +479,7 @@ async def _reverse_topological_power_sort(
event_ids: The events to sort
event_map
state_res_store
- auth_diff: Set of event IDs that are in the auth difference.
+ full_conflicted_set: Set of event IDs that are in the full conflicted set.
Returns:
The sorted list
@@ -488,7 +488,7 @@ async def _reverse_topological_power_sort(
graph: Dict[str, Set[str]] = {}
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
- graph, room_id, event_id, event_map, state_res_store, auth_diff
+ graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
)
# We await occasionally when we're working with large data sets to
|