summary refs log tree commit diff
path: root/synapse/storage/controllers/persist_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/controllers/persist_events.py')
-rw-r--r--synapse/storage/controllers/persist_events.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..cf98b0ab48 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
 from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs: "HomeServer", stores: Databases):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        stores: Databases,
+        state_controller: StateStorageController,
+    ):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
             self._process_event_persist_queue_task
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._state_controller = state_controller
 
     async def _process_event_persist_queue_task(
         self,
@@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state
+        return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
         self, _room_id: str, task: _PersistEventsTask
@@ -940,7 +948,8 @@ class EventsPersistenceStorageController:
                 events_context,
             )
 
-        return res.state, None, new_latest_event_ids
+        full_state = await res.get_state(self._state_controller)
+        return full_state, None, new_latest_event_ids
 
     async def _prune_extremities(
         self,