summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py117
1 files changed, 67 insertions, 50 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 781d9f06da..9f0a36652c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -31,7 +31,6 @@ from typing import (
     Sequence,
     Set,
     Tuple,
-    Union,
 )
 
 import attr
@@ -47,6 +46,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -54,6 +54,7 @@ from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.controllers import StateStorageController
     from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
@@ -83,17 +84,20 @@ def _gen_state_id() -> str:
 
 
 class _StateCacheEntry:
-    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+    __slots__ = ["state", "state_group", "prev_group", "delta_ids"]
 
     def __init__(
         self,
-        state: StateMap[str],
+        state: Optional[StateMap[str]],
         state_group: Optional[int],
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ):
+        if state is None and state_group is None:
+            raise Exception("Either state or state group must be not None")
+
         # A map from (type, state_key) to event_id.
-        self.state = frozendict(state)
+        self.state = frozendict(state) if state is not None else None
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
@@ -102,20 +106,30 @@ class _StateCacheEntry:
         self.prev_group = prev_group
         self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
-        # The `state_id` is a unique ID we generate that can be used as ID for
-        # this collection of state. Usually this would be the same as the
-        # state group, but on worker instances we can't generate a new state
-        # group each time we resolve state, so we generate a separate one that
-        # isn't persisted and is used solely for caches.
-        # `state_id` is either a state_group (and so an int) or a string. This
-        # ensures we don't accidentally persist a state_id as a stateg_group
-        if state_group:
-            self.state_id: Union[str, int] = state_group
-        else:
-            self.state_id = _gen_state_id()
+    async def get_state(
+        self,
+        state_storage: "StateStorageController",
+        state_filter: Optional["StateFilter"] = None,
+    ) -> StateMap[str]:
+        """Get the state map for this entry, either from the in-memory state or
+        looking up the state group in the DB.
+        """
+
+        if self.state is not None:
+            return self.state
+
+        assert self.state_group is not None
+
+        return await state_storage.get_state_ids_for_group(
+            self.state_group, state_filter
+        )
 
     def __len__(self) -> int:
-        return len(self.state)
+        # The len should is used to estimate how large this cache entry is, for
+        # cache eviction purposes. This is why if `self.state` is None it's fine
+        # to return 1.
+
+        return len(self.state) if self.state else 1
 
 
 class StateHandler:
@@ -153,7 +167,7 @@ class StateHandler:
         """
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return ret.state
+        return await ret.get_state(self._state_storage_controller, StateFilter.all())
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: List[str]
@@ -177,7 +191,8 @@ class StateHandler:
 
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return await self.store.get_joined_users_from_state(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_users_from_state(room_id, state, entry)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
@@ -192,7 +207,8 @@ class StateHandler:
             The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        return await self.store.get_joined_hosts(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_hosts(room_id, state, entry)
 
     async def compute_event_context(
         self,
@@ -227,10 +243,19 @@ class StateHandler:
         #
         if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
-            entry = None
+
+            # .. though we need to get a state group for it.
+            state_group_before_event = (
+                await self._state_storage_controller.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=None,
+                    delta_ids=None,
+                    current_state_ids=state_ids_before_event,
+                )
+            )
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +289,27 @@ class StateHandler:
                 await_full_state=False,
             )
 
-            state_ids_before_event = entry.state
-            state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
 
-        #
-        # make sure that we have a state group at that point. If it's not a state event,
-        # that will be the state group for the new event. If it *is* a state event,
-        # it might get rejected (in which case we'll need to persist it with the
-        # previous state group)
-        #
-
-        if not state_group_before_event:
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=state_group_before_event_prev_group,
-                    delta_ids=deltas_to_state_group_before_event,
-                    current_state_ids=state_ids_before_event,
+            # We make sure that we have a state group assigned to the state.
+            if entry.state_group is None:
+                state_ids_before_event = await entry.get_state(
+                    self._state_storage_controller, StateFilter.all()
+                )
+                state_group_before_event = (
+                    await self._state_storage_controller.store_state_group(
+                        event.event_id,
+                        event.room_id,
+                        prev_group=state_group_before_event_prev_group,
+                        delta_ids=deltas_to_state_group_before_event,
+                        current_state_ids=state_ids_before_event,
+                    )
                 )
-            )
-
-            # Assign the new state group to the cached state entry.
-            #
-            # Note that this can race in that we could generate multiple state
-            # groups for the same state entry, but that is just inefficient
-            # rather than dangerous.
-            if entry and entry.state_group is None:
                 entry.state_group = state_group_before_event
+            else:
+                state_group_before_event = entry.state_group
+                state_ids_before_event = None
 
         #
         # now if it's not a state event, we're done
@@ -313,6 +329,10 @@ class StateHandler:
         #
         # otherwise, we'll need to create a new state group for after the event
         #
+        if state_ids_before_event is None:
+            state_ids_before_event = await entry.get_state(
+                self._state_storage_controller, StateFilter.all()
+            )
 
         key = (event.type, event.state_key)
         if key in state_ids_before_event:
@@ -372,9 +392,6 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self._state_storage_controller.get_state_for_groups(
-                state_group_ids_set
-            )
             (
                 prev_group,
                 delta_ids,
@@ -382,7 +399,7 @@ class StateHandler:
                 state_group_id
             )
             return _StateCacheEntry(
-                state=state[state_group_id],
+                state=None,
                 state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,