summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/__init__.py4
-rw-r--r--synapse/storage/controllers/persist_events.py12
-rw-r--r--synapse/storage/databases/main/roommember.py35
3 files changed, 21 insertions, 30 deletions
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
 
         self.persistence = None
         if stores.persist_events:
-            self.persistence = EventsPersistenceStorageController(hs, stores)
+            self.persistence = EventsPersistenceStorageController(
+                hs, stores, self.state
+            )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..af65e5913b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
 from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs: "HomeServer", stores: Databases):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        stores: Databases,
+        state_controller: StateStorageController,
+    ):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
             self._process_event_persist_queue_task
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._state_controller = state_controller
 
     async def _process_event_persist_queue_task(
         self,
@@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state
+        return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
         self, _room_id: str, task: _PersistEventsTask
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..71a65d565a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ import attr
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
@@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_context(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = await context.get_current_state_ids()
-        assert current_state_ids is not None
-        assert state_group is not None
-        return await self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
-
     async def get_joined_users_from_state(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
-                room_id, state_group, state_entry.state, context=state_entry
+                room_id, state_group, state, context=state_entry
             )
 
-    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+    @cached(num_args=2, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
-        cache_context: _CacheContext,
         event: Optional[EventBase] = None,
-        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
+        context: Optional["_StateCacheEntry"] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -1017,7 +997,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     async def get_joined_hosts(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -1030,7 +1010,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry=state_entry
+                room_id, state_group, state, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1018,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self,
         room_id: str,
         state_group: Union[object, int],
+        state: StateMap[str],
         state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1074,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
                 joined_users = await self.get_joined_users_from_state(
-                    room_id, state_entry
+                    room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}