summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state/__init__.py51
1 files changed, 31 insertions, 20 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9c9d946f38..bf09f5128a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -127,10 +127,10 @@ class StateHandler:
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        self.state_storage = hs.get_storage().state
+        self._state_storage_controller = hs.get_storage_controllers().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
-        self._storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
 
     @overload
     async def get_current_state(
@@ -337,12 +337,14 @@ class StateHandler:
         #
 
         if not state_group_before_event:
-            state_group_before_event = await self.state_storage.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
-                current_state_ids=state_ids_before_event,
+            state_group_before_event = (
+                await self._state_storage_controller.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=state_group_before_event_prev_group,
+                    delta_ids=deltas_to_state_group_before_event,
+                    current_state_ids=state_ids_before_event,
+                )
             )
 
             # Assign the new state group to the cached state entry.
@@ -359,7 +361,7 @@ class StateHandler:
 
         if not event.is_state():
             return EventContext.with_state(
-                storage=self._storage,
+                storage=self._storage_controllers,
                 state_group_before_event=state_group_before_event,
                 state_group=state_group_before_event,
                 state_delta_due_to_event={},
@@ -382,16 +384,18 @@ class StateHandler:
         state_ids_after_event[key] = event.event_id
         delta_ids = {key: event.event_id}
 
-        state_group_after_event = await self.state_storage.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
-            current_state_ids=state_ids_after_event,
+        state_group_after_event = (
+            await self._state_storage_controller.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=state_group_before_event,
+                delta_ids=delta_ids,
+                current_state_ids=state_ids_after_event,
+            )
         )
 
         return EventContext.with_state(
-            storage=self._storage,
+            storage=self._storage_controllers,
             state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
             state_delta_due_to_event=delta_ids,
@@ -416,7 +420,9 @@ class StateHandler:
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        state_groups = await self.state_storage.get_state_group_for_events(event_ids)
+        state_groups = await self._state_storage_controller.get_state_group_for_events(
+            event_ids
+        )
 
         state_group_ids = state_groups.values()
 
@@ -424,8 +430,13 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self.state_storage.get_state_for_groups(state_group_ids_set)
-            prev_group, delta_ids = await self.state_storage.get_state_group_delta(
+            state = await self._state_storage_controller.get_state_for_groups(
+                state_group_ids_set
+            )
+            (
+                prev_group,
+                delta_ids,
+            ) = await self._state_storage_controller.get_state_group_delta(
                 state_group_id
             )
             return _StateCacheEntry(
@@ -439,7 +450,7 @@ class StateHandler:
 
         room_version = await self.store.get_room_version_id(room_id)
 
-        state_to_resolve = await self.state_storage.get_state_for_groups(
+        state_to_resolve = await self._state_storage_controller.get_state_for_groups(
             state_group_ids_set
         )