diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
self.persistence = None
if stores.persist_events:
- self.persistence = EventsPersistenceStorageController(hs, stores)
+ self.persistence = EventsPersistenceStorageController(
+ hs, stores, self.state
+ )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..af65e5913b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
current state and forward extremity changes.
"""
- def __init__(self, hs: "HomeServer", stores: Databases):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ stores: Databases,
+ state_controller: StateStorageController,
+ ):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
self._process_event_persist_queue_task
)
self._state_resolution_handler = hs.get_state_resolution_handler()
+ self._state_controller = state_controller
async def _process_event_persist_queue_task(
self,
@@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
state_res_store=StateResolutionStore(self.main_store),
)
- return res.state
+ return await res.get_state(self._state_controller, StateFilter.all())
async def _persist_event_batch(
self, _room_id: str, task: _PersistEventsTask
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..71a65d565a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ import attr
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset()
- async def get_joined_users_from_context(
- self, event: EventBase, context: EventContext
- ) -> Dict[str, ProfileInfo]:
- state_group: Union[object, int] = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = await context.get_current_state_ids()
- assert current_state_ids is not None
- assert state_group is not None
- return await self._get_joined_users_from_context(
- event.room_id, state_group, current_state_ids, event=event, context=context
- )
-
async def get_joined_users_from_state(
- self, room_id: str, state_entry: "_StateCacheEntry"
+ self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
+ room_id, state_group, state, context=state_entry
)
- @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+ @cached(num_args=2, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
- cache_context: _CacheContext,
event: Optional[EventBase] = None,
- context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
+ context: Optional["_StateCacheEntry"] = None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -1017,7 +997,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
async def get_joined_hosts(
- self, room_id: str, state_entry: "_StateCacheEntry"
+ self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
@@ -1030,7 +1010,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
- room_id, state_group, state_entry=state_entry
+ room_id, state_group, state, state_entry=state_entry
)
@cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1018,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
room_id: str,
state_group: Union[object, int],
+ state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1074,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state(
- room_id, state_entry
+ room_id, state, state_entry
)
cache.hosts_to_joined_users = {}
|