diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9c9d946f38..bf09f5128a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -127,10 +127,10 @@ class StateHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- self.state_storage = hs.get_storage().state
+ self._state_storage_controller = hs.get_storage_controllers().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
- self._storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
@overload
async def get_current_state(
@@ -337,12 +337,14 @@ class StateHandler:
#
if not state_group_before_event:
- state_group_before_event = await self.state_storage.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- current_state_ids=state_ids_before_event,
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
)
# Assign the new state group to the cached state entry.
@@ -359,7 +361,7 @@ class StateHandler:
if not event.is_state():
return EventContext.with_state(
- storage=self._storage,
+ storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
state_delta_due_to_event={},
@@ -382,16 +384,18 @@ class StateHandler:
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
- state_group_after_event = await self.state_storage.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=state_ids_after_event,
+ state_group_after_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ current_state_ids=state_ids_after_event,
+ )
)
return EventContext.with_state(
- storage=self._storage,
+ storage=self._storage_controllers,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
state_delta_due_to_event=delta_ids,
@@ -416,7 +420,9 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
- state_groups = await self.state_storage.get_state_group_for_events(event_ids)
+ state_groups = await self._state_storage_controller.get_state_group_for_events(
+ event_ids
+ )
state_group_ids = state_groups.values()
@@ -424,8 +430,13 @@ class StateHandler:
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
- state = await self.state_storage.get_state_for_groups(state_group_ids_set)
- prev_group, delta_ids = await self.state_storage.get_state_group_delta(
+ state = await self._state_storage_controller.get_state_for_groups(
+ state_group_ids_set
+ )
+ (
+ prev_group,
+ delta_ids,
+ ) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
@@ -439,7 +450,7 @@ class StateHandler:
room_version = await self.store.get_room_version_id(room_id)
- state_to_resolve = await self.state_storage.get_state_for_groups(
+ state_to_resolve = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set
)
|