diff options
Diffstat (limited to 'synapse/state')
-rw-r--r-- | synapse/state/__init__.py | 51 |
1 files changed, 31 insertions, 20 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9c9d946f38..bf09f5128a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,10 +127,10 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_storage = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() @overload async def get_current_state( @@ -337,12 +337,14 @@ class StateHandler: # if not state_group_before_event: - state_group_before_event = await self.state_storage.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) # Assign the new state group to the cached state entry. @@ -359,7 +361,7 @@ class StateHandler: if not event.is_state(): return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group_before_event=state_group_before_event, state_group=state_group_before_event, state_delta_due_to_event={}, @@ -382,16 +384,18 @@ class StateHandler: state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = await self.state_storage.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=state_ids_after_event, + ) ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group_after_event, state_group_before_event=state_group_before_event, state_delta_due_to_event=delta_ids, @@ -416,7 +420,9 @@ class StateHandler: """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_storage.get_state_group_for_events(event_ids) + state_groups = await self._state_storage_controller.get_state_group_for_events( + event_ids + ) state_group_ids = state_groups.values() @@ -424,8 +430,13 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_storage.get_state_for_groups(state_group_ids_set) - prev_group, delta_ids = await self.state_storage.get_state_group_delta( + state = await self._state_storage_controller.get_state_for_groups( + state_group_ids_set + ) + ( + prev_group, + delta_ids, + ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) return _StateCacheEntry( @@ -439,7 +450,7 @@ class StateHandler: room_version = await self.store.get_room_version_id(room_id) - state_to_resolve = await self.state_storage.get_state_for_groups( + state_to_resolve = await self._state_storage_controller.get_state_for_groups( state_group_ids_set ) |