diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index e3faa52cd6..87ccd52f0a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import heapq
import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
from typing import (
TYPE_CHECKING,
Any,
@@ -92,8 +92,11 @@ class _StateCacheEntry:
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
- if state is None and state_group is None:
- raise Exception("Either state or state group must be not None")
+ if state is None and state_group is None and prev_group is None:
+ raise Exception("One of state, state_group or prev_group must be not None")
+
+ if prev_group is not None and delta_ids is None:
+ raise Exception("If prev_group is set so must delta_ids")
# A map from (type, state_key) to event_id.
#
@@ -120,18 +123,48 @@ class _StateCacheEntry:
if self._state is not None:
return self._state
- assert self.state_group is not None
+ if self.state_group is not None:
+ return await state_storage.get_state_ids_for_group(
+ self.state_group, state_filter
+ )
+
+ assert self.prev_group is not None and self.delta_ids is not None
- return await state_storage.get_state_ids_for_group(
- self.state_group, state_filter
+ prev_state = await state_storage.get_state_ids_for_group(
+ self.prev_group, state_filter
)
+ # ChainMap expects MutableMapping, but since we're using it immutably
+ # it's safe to give it immutable maps.
+ return ChainMap(self.delta_ids, prev_state) # type: ignore[arg-type]
+
+ def set_state_group(self, state_group: int) -> None:
+ """Update the state group assigned to this state (e.g. after we've
+ persisted it).
+
+ Note: this will cause the cache entry to drop any stored state.
+ """
+
+ self.state_group = state_group
+
+ # We clear out the state as we no longer need to explicitly keep it in
+ # the `state_cache` (as the store state group cache will do that).
+ self._state = None
+
def __len__(self) -> int:
- # The len should is used to estimate how large this cache entry is, for
- # cache eviction purposes. This is why if `self.state` is None it's fine
- # to return 1.
+ # The len should be used to estimate how large this cache entry is, for
+ # cache eviction purposes. This is why it's fine to return 1 if we're
+ # not storing any state.
+
+ length = 0
- return len(self._state) if self._state else 1
+ if self._state:
+ length += len(self._state)
+
+ if self.delta_ids:
+ length += len(self.delta_ids)
+
+ return length or 1 # Make sure it's not 0.
class StateHandler:
@@ -320,7 +353,7 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
)
- entry.state_group = state_group_before_event
+ entry.set_state_group(state_group_before_event)
else:
state_group_before_event = entry.state_group
@@ -747,7 +780,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
- return _StateCacheEntry(state=new_state, state_group=sg)
+ return _StateCacheEntry(state=None, state_group=sg)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
@@ -769,9 +802,14 @@ def _make_state_cache_entry(
prev_group = old_group
delta_ids = n_delta_ids
- return _StateCacheEntry(
- state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
- )
+ if prev_group is not None:
+ # If we have a prev group and deltas then we can drop the new state from
+ # the cache (to reduce memory usage).
+ return _StateCacheEntry(
+ state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+ )
+ else:
+ return _StateCacheEntry(state=new_state, state_group=None)
@attr.s(slots=True, auto_attribs=True)
|