summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/roommember.py122
1 files changed, 38 insertions, 84 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 06500457bd..4f0adb136a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ from typing import (
 import attr
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.events import EventBase
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
@@ -883,96 +882,51 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return shared_room_ids or frozenset()
 
     async def get_joined_user_ids_from_state(
-        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str]
     ) -> Set[str]:
-        state_group: Union[object, int] = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        assert state_group is not None
-        with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_user_ids_from_context(
-                room_id, state_group, state, context=state_entry
-            )
+        """
+        For a given set of state IDs, get a set of user IDs in the room.
 
-    @cached(num_args=2, iterable=True, max_entries=100000)
-    async def _get_joined_user_ids_from_context(
-        self,
-        room_id: str,
-        state_group: Union[object, int],
-        current_state_ids: StateMap[str],
-        event: Optional[EventBase] = None,
-        context: Optional["_StateCacheEntry"] = None,
-    ) -> Set[str]:
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        assert state_group is not None
+        This method checks the local event cache, before calling
+        `_get_user_ids_from_membership_event_ids` for any uncached events.
+        """
 
-        users_in_room = set()
-        member_event_ids = [
-            e_id
-            for key, e_id in current_state_ids.items()
-            if key[0] == EventTypes.Member
-        ]
-
-        if context is not None:
-            # If we have a context with a delta from a previous state group,
-            # check if we also have the result from the previous group in cache.
-            # If we do then we can reuse that result and simply update it with
-            # any membership changes in `delta_ids`
-            if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
-                    (room_id, context.prev_group), None
-                )
-                if prev_res and isinstance(prev_res, set):
-                    users_in_room = prev_res
-                    member_event_ids = [
-                        e_id
-                        for key, e_id in context.delta_ids.items()
-                        if key[0] == EventTypes.Member
-                    ]
-                    for etype, state_key in context.delta_ids:
-                        if etype == EventTypes.Member:
-                            users_in_room.discard(state_key)
-
-        # We check if we have any of the member event ids in the event cache
-        # before we ask the DB
-
-        # We don't update the event cache hit ratio as it completely throws off
-        # the hit ratio counts. After all, we don't populate the cache if we
-        # miss it here
-        event_map = self._get_events_from_local_cache(
-            member_event_ids, update_metrics=False
-        )
+        with Measure(self._clock, "get_joined_user_ids_from_state"):
+            users_in_room = set()
+            member_event_ids = [
+                e_id for key, e_id in state.items() if key[0] == EventTypes.Member
+            ]
 
-        missing_member_event_ids = []
-        for event_id in member_event_ids:
-            ev_entry = event_map.get(event_id)
-            if ev_entry and not ev_entry.event.rejected_reason:
-                if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room.add(ev_entry.event.state_key)
-            else:
-                missing_member_event_ids.append(event_id)
+            # We check if we have any of the member event ids in the event cache
+            # before we ask the DB
 
-        if missing_member_event_ids:
-            event_to_memberships = await self._get_user_ids_from_membership_event_ids(
-                missing_member_event_ids
-            )
-            users_in_room.update(
-                user_id for user_id in event_to_memberships.values() if user_id
+            # We don't update the event cache hit ratio as it completely throws off
+            # the hit ratio counts. After all, we don't populate the cache if we
+            # miss it here
+            event_map = self._get_events_from_local_cache(
+                member_event_ids, update_metrics=False
             )
 
-        if event is not None and event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                if event.event_id in member_event_ids:
-                    users_in_room.add(event.state_key)
+            missing_member_event_ids = []
+            for event_id in member_event_ids:
+                ev_entry = event_map.get(event_id)
+                if ev_entry and not ev_entry.event.rejected_reason:
+                    if ev_entry.event.membership == Membership.JOIN:
+                        users_in_room.add(ev_entry.event.state_key)
+                else:
+                    missing_member_event_ids.append(event_id)
+
+            if missing_member_event_ids:
+                event_to_memberships = (
+                    await self._get_user_ids_from_membership_event_ids(
+                        missing_member_event_ids
+                    )
+                )
+                users_in_room.update(
+                    user_id for user_id in event_to_memberships.values() if user_id
+                )
 
-        return users_in_room
+            return users_in_room
 
     @cached(
         max_entries=10000,
@@ -1205,7 +1159,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
                 joined_user_ids = await self.get_joined_user_ids_from_state(
-                    room_id, state, state_entry
+                    room_id, state
                 )
 
                 cache.hosts_to_joined_users = {}