diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c5c538e0c3..b5859dcb28 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -238,8 +238,8 @@ class SyncHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
@@ -512,7 +512,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -580,7 +580,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -630,7 +630,7 @@ class SyncHandler:
event: event of interest
state_filter: The state filter used to fetch state from the database.
"""
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all()
)
if event.is_state():
@@ -710,7 +710,7 @@ class SyncHandler:
return None
last_event = last_events[-1]
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -889,13 +889,15 @@ class SyncHandler:
if full_state:
if batch:
current_state_ids = (
- await self.state_storage.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
)
- state_ids = await self.state_storage.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
)
else:
@@ -915,7 +917,7 @@ class SyncHandler:
elif batch.limited:
if batch:
state_at_timeline_start = (
- await self.state_storage.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
@@ -950,7 +952,7 @@ class SyncHandler:
if batch:
current_state_ids = (
- await self.state_storage.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
)
@@ -982,7 +984,7 @@ class SyncHandler:
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
|