summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/admin.py12
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/events.py4
-rw-r--r--synapse/handlers/federation.py30
-rw-r--r--synapse/handlers/federation_event.py27
-rw-r--r--synapse/handlers/initial_sync.py17
-rw-r--r--synapse/handlers/message.py30
-rw-r--r--synapse/handlers/pagination.py17
-rw-r--r--synapse/handlers/relations.py7
-rw-r--r--synapse/handlers/room.py11
-rw-r--r--synapse/handlers/room_batch.py4
-rw-r--r--synapse/handlers/search.py14
-rw-r--r--synapse/handlers/sync.py26
13 files changed, 119 insertions, 84 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 50e34743b7..d4fe7df533 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
 class AdminHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
 
     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
@@ -197,7 +197,9 @@ class AdminHandler:
 
                 from_key = events[-1].internal_metadata.after
 
-                events = await filter_events_for_client(self.storage, user_id, events)
+                events = await filter_events_for_client(
+                    self._storage_controllers, user_id, events
+                )
 
                 writer.write_events(room_id, events)
 
@@ -233,7 +235,9 @@ class AdminHandler:
             for event_id in extremities:
                 if not event_to_unseen_prevs[event_id]:
                     continue
-                state = await self.state_storage.get_state_for_event(event_id)
+                state = await self._state_storage_controller.get_state_for_event(
+                    event_id
+                )
                 writer.write_state(room_id, event_id, state)
 
         return writer.finished()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 2a56473dc6..72faf2ee38 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -71,7 +71,7 @@ class DeviceWorkerHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
-        self.state_storage = hs.get_storage().state
+        self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
 
@@ -204,7 +204,7 @@ class DeviceWorkerHandler:
                 continue
 
             # mapping from event_id -> state_dict
-            prev_state_ids = await self.state_storage.get_state_ids_for_events(
+            prev_state_ids = await self._state_storage.get_state_ids_for_events(
                 event_ids
             )
 
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index cb7e0ca7a8..ac13340d3a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -139,7 +139,7 @@ class EventStreamHandler:
 class EventHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
 
     async def get_event(
         self,
@@ -177,7 +177,7 @@ class EventHandler:
         is_peeking = user.to_string() not in users
 
         filtered = await filter_events_for_client(
-            self.storage, user.to_string(), [event], is_peeking=is_peeking
+            self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
         )
 
         if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c8233270d7..80ee7e7b4e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -125,8 +125,8 @@ class FederationHandler:
         self.hs = hs
 
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -324,7 +324,7 @@ class FederationHandler:
             # We set `check_history_visibility_only` as we might otherwise get false
             # positives from users having been erased.
             filtered_extremities = await filter_events_for_server(
-                self.storage,
+                self._storage_controllers,
                 self.server_name,
                 events_to_check,
                 redact=False,
@@ -660,7 +660,7 @@ class FederationHandler:
         # in the invitee's sync stream. It is stripped out for all other local users.
         event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
 
-        context = EventContext.for_outlier(self.storage)
+        context = EventContext.for_outlier(self._storage_controllers)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -849,7 +849,7 @@ class FederationHandler:
             )
         )
 
-        context = EventContext.for_outlier(self.storage)
+        context = EventContext.for_outlier(self._storage_controllers)
         await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -878,7 +878,7 @@ class FederationHandler:
 
         await self.federation_client.send_leave(host_list, event)
 
-        context = EventContext.for_outlier(self.storage)
+        context = EventContext.for_outlier(self._storage_controllers)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
             event.room_id, [(event, context)]
         )
@@ -1027,7 +1027,7 @@ class FederationHandler:
         if event.internal_metadata.outlier:
             raise NotFoundError("State not known at event %s" % (event_id,))
 
-        state_groups = await self.state_storage.get_state_groups_ids(
+        state_groups = await self._state_storage_controller.get_state_groups_ids(
             room_id, [event_id]
         )
 
@@ -1078,7 +1078,9 @@ class FederationHandler:
             ],
         )
 
-        events = await filter_events_for_server(self.storage, origin, events)
+        events = await filter_events_for_server(
+            self._storage_controllers, origin, events
+        )
 
         return events
 
@@ -1109,7 +1111,9 @@ class FederationHandler:
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
-            events = await filter_events_for_server(self.storage, origin, [event])
+            events = await filter_events_for_server(
+                self._storage_controllers, origin, [event]
+            )
             event = events[0]
             return event
         else:
@@ -1138,7 +1142,7 @@ class FederationHandler:
         )
 
         missing_events = await filter_events_for_server(
-            self.storage, origin, missing_events
+            self._storage_controllers, origin, missing_events
         )
 
         return missing_events
@@ -1480,9 +1484,11 @@ class FederationHandler:
                 # clear the lazy-loading flag.
                 logger.info("Updating current state for %s", room_id)
                 assert (
-                    self.storage.persistence is not None
+                    self._storage_controllers.persistence is not None
                 ), "TODO(faster_joins): support for workers"
-                await self.storage.persistence.update_current_state(room_id)
+                await self._storage_controllers.persistence.update_current_state(
+                    room_id
+                )
 
                 logger.info("Clearing partial-state flag for %s", room_id)
                 success = await self.store.clear_partial_state_room(room_id)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index a1361af272..b908674529 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -98,8 +98,8 @@ class FederationEventHandler:
 
     def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastores().main
-        self._storage = hs.get_storage()
-        self._state_storage = self._storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
 
         self._state_handler = hs.get_state_handler()
         self._event_creation_handler = hs.get_event_creation_handler()
@@ -535,7 +535,9 @@ class FederationEventHandler:
                 )
                 return
             await self._store.update_state_for_partial_state_event(event, context)
-            self._state_storage.notify_event_un_partial_stated(event.event_id)
+            self._state_storage_controller.notify_event_un_partial_stated(
+                event.event_id
+            )
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -835,7 +837,9 @@ class FederationEventHandler:
 
         try:
             # Get the state of the events we know about
-            ours = await self._state_storage.get_state_groups_ids(room_id, seen)
+            ours = await self._state_storage_controller.get_state_groups_ids(
+                room_id, seen
+            )
 
             # state_maps is a list of mappings from (type, state_key) to event_id
             state_maps: List[StateMap[str]] = list(ours.values())
@@ -1436,7 +1440,7 @@ class FederationEventHandler:
                 # we're not bothering about room state, so flag the event as an outlier.
                 event.internal_metadata.outlier = True
 
-                context = EventContext.for_outlier(self._storage)
+                context = EventContext.for_outlier(self._storage_controllers)
                 try:
                     validate_event_for_room_version(room_version_obj, event)
                     check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1613,7 +1617,7 @@ class FederationEventHandler:
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets_d = await self._state_storage.get_state_groups_ids(
+            state_sets_d = await self._state_storage_controller.get_state_groups_ids(
                 event.room_id, extrem_ids
             )
             state_sets: List[StateMap[str]] = list(state_sets_d.values())
@@ -1885,7 +1889,7 @@ class FederationEventHandler:
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = await self._state_storage.store_state_group(
+        state_group = await self._state_storage_controller.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
@@ -1894,7 +1898,7 @@ class FederationEventHandler:
         )
 
         return EventContext.with_state(
-            storage=self._storage,
+            storage=self._storage_controllers,
             state_group=state_group,
             state_group_before_event=context.state_group_before_event,
             state_delta_due_to_event=state_updates,
@@ -1984,11 +1988,14 @@ class FederationEventHandler:
                 )
             return result["max_stream_id"]
         else:
-            assert self._storage.persistence
+            assert self._storage_controllers.persistence
 
             # Note that this returns the events that were persisted, which may not be
             # the same as were passed in if some were deduplicated due to transaction IDs.
-            events, max_stream_token = await self._storage.persistence.persist_events(
+            (
+                events,
+                max_stream_token,
+            ) = await self._storage_controllers.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index fbdbeeedfd..d2b489e816 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -67,8 +67,8 @@ class InitialSyncHandler:
             ]
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
 
     async def snapshot_all_rooms(
         self,
@@ -198,7 +198,8 @@ class InitialSyncHandler:
                         event.stream_ordering,
                     )
                     deferred_room_state = run_in_background(
-                        self.state_storage.get_state_for_events, [event.event_id]
+                        self._state_storage_controller.get_state_for_events,
+                        [event.event_id],
                     ).addCallback(
                         lambda states: cast(StateMap[EventBase], states[event.event_id])
                     )
@@ -218,7 +219,7 @@ class InitialSyncHandler:
                 ).addErrback(unwrapFirstError)
 
                 messages = await filter_events_for_client(
-                    self.storage, user_id, messages
+                    self._storage_controllers, user_id, messages
                 )
 
                 start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
@@ -355,7 +356,9 @@ class InitialSyncHandler:
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
-        room_state = await self.state_storage.get_state_for_event(member_event_id)
+        room_state = await self._state_storage_controller.get_state_for_event(
+            member_event_id
+        )
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
@@ -369,7 +372,7 @@ class InitialSyncHandler:
         )
 
         messages = await filter_events_for_client(
-            self.storage, user_id, messages, is_peeking=is_peeking
+            self._storage_controllers, user_id, messages, is_peeking=is_peeking
         )
 
         start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
@@ -474,7 +477,7 @@ class InitialSyncHandler:
         )
 
         messages = await filter_events_for_client(
-            self.storage, user_id, messages, is_peeking=is_peeking
+            self._storage_controllers, user_id, messages, is_peeking=is_peeking
         )
 
         start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 38b71a2c96..f377769071 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -84,8 +84,8 @@ class MessageHandler:
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
         self._event_serializer = hs.get_event_client_serializer()
         self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
 
@@ -132,7 +132,7 @@ class MessageHandler:
             assert (
                 membership_event_id is not None
             ), "check_user_in_room_or_world_readable returned invalid data"
-            room_state = await self.state_storage.get_state_for_events(
+            room_state = await self._state_storage_controller.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
@@ -193,7 +193,7 @@ class MessageHandler:
 
             # check whether the user is in the room at that time to determine
             # whether they should be treated as peeking.
-            state_map = await self.state_storage.get_state_for_event(
+            state_map = await self._state_storage_controller.get_state_for_event(
                 last_event.event_id,
                 StateFilter.from_types([(EventTypes.Member, user_id)]),
             )
@@ -206,7 +206,7 @@ class MessageHandler:
             is_peeking = not joined
 
             visible_events = await filter_events_for_client(
-                self.storage,
+                self._storage_controllers,
                 user_id,
                 [last_event],
                 filter_send_to_client=False,
@@ -214,8 +214,10 @@ class MessageHandler:
             )
 
             if visible_events:
-                room_state_events = await self.state_storage.get_state_for_events(
-                    [last_event.event_id], state_filter=state_filter
+                room_state_events = (
+                    await self._state_storage_controller.get_state_for_events(
+                        [last_event.event_id], state_filter=state_filter
+                    )
                 )
                 room_state: Mapping[Any, EventBase] = room_state_events[
                     last_event.event_id
@@ -244,8 +246,10 @@ class MessageHandler:
                 assert (
                     membership_event_id is not None
                 ), "check_user_in_room_or_world_readable returned invalid data"
-                room_state_events = await self.state_storage.get_state_for_events(
-                    [membership_event_id], state_filter=state_filter
+                room_state_events = (
+                    await self._state_storage_controller.get_state_for_events(
+                        [membership_event_id], state_filter=state_filter
+                    )
                 )
                 room_state = room_state_events[membership_event_id]
 
@@ -402,7 +406,7 @@ class EventCreationHandler:
         self.auth = hs.get_auth()
         self._event_auth_handler = hs.get_event_auth_handler()
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
@@ -1032,7 +1036,7 @@ class EventCreationHandler:
         # after it is created
         if builder.internal_metadata.outlier:
             event.internal_metadata.outlier = True
-            context = EventContext.for_outlier(self.storage)
+            context = EventContext.for_outlier(self._storage_controllers)
         elif (
             event.type == EventTypes.MSC2716_INSERTION
             and state_event_ids
@@ -1445,7 +1449,7 @@ class EventCreationHandler:
         """
         extra_users = extra_users or []
 
-        assert self.storage.persistence is not None
+        assert self._storage_controllers.persistence is not None
         assert self._events_shard_config.should_handle(
             self._instance_name, event.room_id
         )
@@ -1679,7 +1683,7 @@ class EventCreationHandler:
             event,
             event_pos,
             max_stream_token,
-        ) = await self.storage.persistence.persist_event(
+        ) = await self._storage_controllers.persistence.persist_event(
             event, context=context, backfilled=backfilled
         )
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 35afe6b855..6262a35822 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -129,8 +129,8 @@ class PaginationHandler:
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
         self.clock = hs.get_clock()
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -352,7 +352,7 @@ class PaginationHandler:
         self._purges_in_progress_by_room.add(room_id)
         try:
             async with self.pagination_lock.write(room_id):
-                await self.storage.purge_events.purge_history(
+                await self._storage_controllers.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
             logger.info("[purge] complete")
@@ -414,7 +414,7 @@ class PaginationHandler:
                 if joined:
                     raise SynapseError(400, "Users are still joined to this room")
 
-            await self.storage.purge_events.purge_room(room_id)
+            await self._storage_controllers.purge_events.purge_room(room_id)
 
     async def get_messages(
         self,
@@ -529,7 +529,10 @@ class PaginationHandler:
             events = await event_filter.filter(events)
 
         events = await filter_events_for_client(
-            self.storage, user_id, events, is_peeking=(member_event_id is None)
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
         )
 
         # if after the filter applied there are no more events
@@ -550,7 +553,7 @@ class PaginationHandler:
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = await self.state_storage.get_state_ids_for_event(
+            state_ids = await self._state_storage_controller.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
@@ -664,7 +667,7 @@ class PaginationHandler:
                                 400, "Users are still joined to this room"
                             )
 
-                    await self.storage.purge_events.purge_room(room_id)
+                    await self._storage_controllers.purge_events.purge_room(room_id)
 
             logger.info("complete")
             self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index ab7e54857d..9a1cc11bb3 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -69,7 +69,7 @@ class BundledAggregations:
 class RelationsHandler:
     def __init__(self, hs: "HomeServer"):
         self._main_store = hs.get_datastores().main
-        self._storage = hs.get_storage()
+        self._storage_controllers = hs.get_storage_controllers()
         self._auth = hs.get_auth()
         self._clock = hs.get_clock()
         self._event_handler = hs.get_event_handler()
@@ -143,7 +143,10 @@ class RelationsHandler:
         )
 
         events = await filter_events_for_client(
-            self._storage, user_id, events, is_peeking=(member_event_id is None)
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
         )
 
         now = self._clock.time_msec()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e2775b34f1..5c91d33f58 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1192,8 +1192,8 @@ class RoomContextHandler:
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
         self._relations_handler = hs.get_relations_handler()
 
     async def get_event_context(
@@ -1236,7 +1236,10 @@ class RoomContextHandler:
             if use_admin_priviledge:
                 return events
             return await filter_events_for_client(
-                self.storage, user.to_string(), events, is_peeking=is_peeking
+                self._storage_controllers,
+                user.to_string(),
+                events,
+                is_peeking=is_peeking,
             )
 
         event = await self.store.get_event(
@@ -1293,7 +1296,7 @@ class RoomContextHandler:
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = await self.state_storage.get_state_for_events(
+        state = await self._state_storage_controller.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
 
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 7ce32f2e9c..1414e575d6 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -17,7 +17,7 @@ class RoomBatchHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastores().main
-        self.state_storage = hs.get_storage().state
+        self._state_storage_controller = hs.get_storage_controllers().state
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
@@ -141,7 +141,7 @@ class RoomBatchHandler:
         ) = await self.store.get_max_depth_of(event_ids)
         # mapping from (type, state_key) -> state_event_id
         assert most_recent_event_id is not None
-        prev_state_map = await self.state_storage.get_state_ids_for_event(
+        prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
             most_recent_event_id
         )
         # List of state event ID's
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index e02c915248..659f99f7e2 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -55,8 +55,8 @@ class SearchHandler:
         self.hs = hs
         self._event_serializer = hs.get_event_client_serializer()
         self._relations_handler = hs.get_relations_handler()
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
         self.auth = hs.get_auth()
 
     async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@@ -460,7 +460,7 @@ class SearchHandler:
         filtered_events = await search_filter.filter([r["event"] for r in results])
 
         events = await filter_events_for_client(
-            self.storage, user.to_string(), filtered_events
+            self._storage_controllers, user.to_string(), filtered_events
         )
 
         events.sort(key=lambda e: -rank_map[e.event_id])
@@ -559,7 +559,7 @@ class SearchHandler:
             filtered_events = await search_filter.filter([r["event"] for r in results])
 
             events = await filter_events_for_client(
-                self.storage, user.to_string(), filtered_events
+                self._storage_controllers, user.to_string(), filtered_events
             )
 
             room_events.extend(events)
@@ -644,11 +644,11 @@ class SearchHandler:
             )
 
             events_before = await filter_events_for_client(
-                self.storage, user.to_string(), res.events_before
+                self._storage_controllers, user.to_string(), res.events_before
             )
 
             events_after = await filter_events_for_client(
-                self.storage, user.to_string(), res.events_after
+                self._storage_controllers, user.to_string(), res.events_after
             )
 
             context: JsonDict = {
@@ -677,7 +677,7 @@ class SearchHandler:
                     [(EventTypes.Member, sender) for sender in senders]
                 )
 
-                state = await self.state_storage.get_state_for_event(
+                state = await self._state_storage_controller.get_state_for_event(
                     last_event_id, state_filter
                 )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c5c538e0c3..b5859dcb28 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -238,8 +238,8 @@ class SyncHandler:
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
 
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
@@ -512,7 +512,7 @@ class SyncHandler:
                     current_state_ids = frozenset(current_state_ids_map.values())
 
                 recents = await filter_events_for_client(
-                    self.storage,
+                    self._storage_controllers,
                     sync_config.user.to_string(),
                     recents,
                     always_include_ids=current_state_ids,
@@ -580,7 +580,7 @@ class SyncHandler:
                     current_state_ids = frozenset(current_state_ids_map.values())
 
                 loaded_recents = await filter_events_for_client(
-                    self.storage,
+                    self._storage_controllers,
                     sync_config.user.to_string(),
                     loaded_recents,
                     always_include_ids=current_state_ids,
@@ -630,7 +630,7 @@ class SyncHandler:
             event: event of interest
             state_filter: The state filter used to fetch state from the database.
         """
-        state_ids = await self.state_storage.get_state_ids_for_event(
+        state_ids = await self._state_storage_controller.get_state_ids_for_event(
             event.event_id, state_filter=state_filter or StateFilter.all()
         )
         if event.is_state():
@@ -710,7 +710,7 @@ class SyncHandler:
             return None
 
         last_event = last_events[-1]
-        state_ids = await self.state_storage.get_state_ids_for_event(
+        state_ids = await self._state_storage_controller.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -889,13 +889,15 @@ class SyncHandler:
             if full_state:
                 if batch:
                     current_state_ids = (
-                        await self.state_storage.get_state_ids_for_event(
+                        await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[-1].event_id, state_filter=state_filter
                         )
                     )
 
-                    state_ids = await self.state_storage.get_state_ids_for_event(
-                        batch.events[0].event_id, state_filter=state_filter
+                    state_ids = (
+                        await self._state_storage_controller.get_state_ids_for_event(
+                            batch.events[0].event_id, state_filter=state_filter
+                        )
                     )
 
                 else:
@@ -915,7 +917,7 @@ class SyncHandler:
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
-                        await self.state_storage.get_state_ids_for_event(
+                        await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[0].event_id, state_filter=state_filter
                         )
                     )
@@ -950,7 +952,7 @@ class SyncHandler:
 
                 if batch:
                     current_state_ids = (
-                        await self.state_storage.get_state_ids_for_event(
+                        await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[-1].event_id, state_filter=state_filter
                         )
                     )
@@ -982,7 +984,7 @@ class SyncHandler:
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = await self.state_storage.get_state_ids_for_event(
+                        state_ids = await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(