summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py68
1 files changed, 53 insertions, 15 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index e3faa52cd6..87ccd52f0a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 import logging
-from collections import defaultdict
+from collections import ChainMap, defaultdict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -92,8 +92,11 @@ class _StateCacheEntry:
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ):
-        if state is None and state_group is None:
-            raise Exception("Either state or state group must be not None")
+        if state is None and state_group is None and prev_group is None:
+            raise Exception("One of state, state_group or prev_group must be not None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("If prev_group is set so must delta_ids")
 
         # A map from (type, state_key) to event_id.
         #
@@ -120,18 +123,48 @@ class _StateCacheEntry:
         if self._state is not None:
             return self._state
 
-        assert self.state_group is not None
+        if self.state_group is not None:
+            return await state_storage.get_state_ids_for_group(
+                self.state_group, state_filter
+            )
+
+        assert self.prev_group is not None and self.delta_ids is not None
 
-        return await state_storage.get_state_ids_for_group(
-            self.state_group, state_filter
+        prev_state = await state_storage.get_state_ids_for_group(
+            self.prev_group, state_filter
         )
 
+        # ChainMap expects MutableMapping, but since we're using it immutably
+        # its safe to give it immutable maps.
+        return ChainMap(self.delta_ids, prev_state)  # type: ignore[arg-type]
+
+    def set_state_group(self, state_group: int) -> None:
+        """Update the state group assigned to this state (e.g. after we've
+        persisted it).
+
+        Note: this will cause the cache entry to drop any stored state.
+        """
+
+        self.state_group = state_group
+
+        # We clear out the state as we know longer need to explicitly keep it in
+        # the `state_cache` (as the store state group cache will do that).
+        self._state = None
+
     def __len__(self) -> int:
-        # The len should is used to estimate how large this cache entry is, for
-        # cache eviction purposes. This is why if `self.state` is None it's fine
-        # to return 1.
+        # The len should be used to estimate how large this cache entry is, for
+        # cache eviction purposes. This is why it's fine to return 1 if we're
+        # not storing any state.
+
+        length = 0
 
-        return len(self._state) if self._state else 1
+        if self._state:
+            length += len(self._state)
+
+        if self.delta_ids:
+            length += len(self.delta_ids)
+
+        return length or 1  # Make sure its not 0.
 
 
 class StateHandler:
@@ -320,7 +353,7 @@ class StateHandler:
                         current_state_ids=state_ids_before_event,
                     )
                 )
-                entry.state_group = state_group_before_event
+                entry.set_state_group(state_group_before_event)
             else:
                 state_group_before_event = entry.state_group
 
@@ -747,7 +780,7 @@ def _make_state_cache_entry(
         old_state_event_ids = set(state.values())
         if new_state_event_ids == old_state_event_ids:
             # got an exact match.
-            return _StateCacheEntry(state=new_state, state_group=sg)
+            return _StateCacheEntry(state=None, state_group=sg)
 
     # TODO: We want to create a state group for this set of events, to
     # increase cache hits, but we need to make sure that it doesn't
@@ -769,9 +802,14 @@ def _make_state_cache_entry(
             prev_group = old_group
             delta_ids = n_delta_ids
 
-    return _StateCacheEntry(
-        state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
-    )
+    if prev_group is not None:
+        # If we have a prev group and deltas then we can drop the new state from
+        # the cache (to reduce memory usage).
+        return _StateCacheEntry(
+            state=None, state_group=None, prev_group=prev_group, delta_ids=delta_ids
+        )
+    else:
+        return _StateCacheEntry(state=new_state, state_group=None)
 
 
 @attr.s(slots=True, auto_attribs=True)