summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py35
1 files changed, 8 insertions, 27 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..71a65d565a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ import attr
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
@@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_context(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = await context.get_current_state_ids()
-        assert current_state_ids is not None
-        assert state_group is not None
-        return await self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
-
     async def get_joined_users_from_state(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
-                room_id, state_group, state_entry.state, context=state_entry
+                room_id, state_group, state, context=state_entry
             )
 
-    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+    @cached(num_args=2, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
-        cache_context: _CacheContext,
         event: Optional[EventBase] = None,
-        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
+        context: Optional["_StateCacheEntry"] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -1017,7 +997,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     async def get_joined_hosts(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -1030,7 +1010,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry=state_entry
+                room_id, state_group, state, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1018,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self,
         room_id: str,
         state_group: Union[object, int],
+        state: StateMap[str],
         state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1074,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
                 joined_users = await self.get_joined_users_from_state(
-                    room_id, state_entry
+                    room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}