summary refs log tree commit diff
path: root/synapse/handlers/sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--synapse/handlers/sync.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c5c538e0c3..b5859dcb28 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -238,8 +238,8 @@ class SyncHandler:
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
-        self.storage = hs.get_storage()
-        self.state_storage = self.storage.state
+        self._storage_controllers = hs.get_storage_controllers()
+        self._state_storage_controller = self._storage_controllers.state
 
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
@@ -512,7 +512,7 @@ class SyncHandler:
                     current_state_ids = frozenset(current_state_ids_map.values())
 
                 recents = await filter_events_for_client(
-                    self.storage,
+                    self._storage_controllers,
                     sync_config.user.to_string(),
                     recents,
                     always_include_ids=current_state_ids,
@@ -580,7 +580,7 @@ class SyncHandler:
                     current_state_ids = frozenset(current_state_ids_map.values())
 
                 loaded_recents = await filter_events_for_client(
-                    self.storage,
+                    self._storage_controllers,
                     sync_config.user.to_string(),
                     loaded_recents,
                     always_include_ids=current_state_ids,
@@ -630,7 +630,7 @@ class SyncHandler:
             event: event of interest
             state_filter: The state filter used to fetch state from the database.
         """
-        state_ids = await self.state_storage.get_state_ids_for_event(
+        state_ids = await self._state_storage_controller.get_state_ids_for_event(
             event.event_id, state_filter=state_filter or StateFilter.all()
         )
         if event.is_state():
@@ -710,7 +710,7 @@ class SyncHandler:
             return None
 
         last_event = last_events[-1]
-        state_ids = await self.state_storage.get_state_ids_for_event(
+        state_ids = await self._state_storage_controller.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -889,13 +889,15 @@ class SyncHandler:
             if full_state:
                 if batch:
                     current_state_ids = (
-                        await self.state_storage.get_state_ids_for_event(
+                        await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[-1].event_id, state_filter=state_filter
                         )
                     )
 
-                    state_ids = await self.state_storage.get_state_ids_for_event(
-                        batch.events[0].event_id, state_filter=state_filter
+                    state_ids = (
+                        await self._state_storage_controller.get_state_ids_for_event(
+                            batch.events[0].event_id, state_filter=state_filter
+                        )
                     )
 
                 else:
@@ -915,7 +917,7 @@ class SyncHandler:
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
-                        await self.state_storage.get_state_ids_for_event(
+                        await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[0].event_id, state_filter=state_filter
                         )
                     )
@@ -950,7 +952,7 @@ class SyncHandler:
 
                 if batch:
                     current_state_ids = (
-                        await self.state_storage.get_state_ids_for_event(
+                        await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[-1].event_id, state_filter=state_filter
                         )
                     )
@@ -982,7 +984,7 @@ class SyncHandler:
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = await self.state_storage.get_state_ids_for_event(
+                        state_ids = await self._state_storage_controller.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(