summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py36
1 files changed, 17 insertions, 19 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4b4ed42cff..9c9d946f38 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -127,7 +127,7 @@ class StateHandler:
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        self.state_store = hs.get_storage().state
+        self.state_storage = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._storage = hs.get_storage()
@@ -261,7 +261,7 @@ class StateHandler:
     async def compute_event_context(
         self,
         event: EventBase,
-        old_state: Optional[Iterable[EventBase]] = None,
+        state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: bool = False,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
@@ -273,12 +273,12 @@ class StateHandler:
 
         Args:
             event:
-            old_state: The state at the event if it can't be
-                calculated from existing events. This is normally only specified
-                when receiving an event from federation where we don't have the
-                prev events for, e.g. when backfilling.
-            partial_state: True if `old_state` is partial and omits non-critical
-                membership events
+            state_ids_before_event: The event ids of the state before the event if
+                it can't be calculated from existing events. This is normally
+                only specified when receiving an event from federation where we
+                don't have the prev events, e.g. when backfilling.
+            partial_state: True if `state_ids_before_event` is partial and omits
+                non-critical membership events
         Returns:
             The event context.
         """
@@ -286,13 +286,11 @@ class StateHandler:
         assert not event.internal_metadata.is_outlier()
 
         #
-        # first of all, figure out the state before the event
+        # first of all, figure out the state before the event, unless we
+        # already have it.
         #
-        if old_state:
+        if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_ids_before_event: StateMap[str] = {
-                (s.type, s.state_key): s.event_id for s in old_state
-            }
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
@@ -339,7 +337,7 @@ class StateHandler:
         #
 
         if not state_group_before_event:
-            state_group_before_event = await self.state_store.store_state_group(
+            state_group_before_event = await self.state_storage.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=state_group_before_event_prev_group,
@@ -384,7 +382,7 @@ class StateHandler:
         state_ids_after_event[key] = event.event_id
         delta_ids = {key: event.event_id}
 
-        state_group_after_event = await self.state_store.store_state_group(
+        state_group_after_event = await self.state_storage.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=state_group_before_event,
@@ -418,7 +416,7 @@ class StateHandler:
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        state_groups = await self.state_store.get_state_group_for_events(event_ids)
+        state_groups = await self.state_storage.get_state_group_for_events(event_ids)
 
         state_group_ids = state_groups.values()
 
@@ -426,8 +424,8 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self.state_store.get_state_for_groups(state_group_ids_set)
-            prev_group, delta_ids = await self.state_store.get_state_group_delta(
+            state = await self.state_storage.get_state_for_groups(state_group_ids_set)
+            prev_group, delta_ids = await self.state_storage.get_state_group_delta(
                 state_group_id
             )
             return _StateCacheEntry(
@@ -441,7 +439,7 @@ class StateHandler:
 
         room_version = await self.store.get_room_version_id(room_id)
 
-        state_to_resolve = await self.state_store.get_state_for_groups(
+        state_to_resolve = await self.state_storage.get_state_for_groups(
             state_group_ids_set
         )