summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/roommember.py46
1 files changed, 44 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 8e2ba7b7b4..ea6a5e2f34 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -1131,12 +1132,33 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
+                #
+                # We need to fetch all hosts joined to the room according to `state` by
+                # inspecting all join memberships in `state`. However, if the `state` is
+                # relatively recent then many of its events are likely to be held in
+                # the current state of the room, which is easily available and likely
+                # cached.
+                #
+                # We therefore compute the set of `state` events not in the
+                # current state and only fetch those.
+                current_memberships = (
+                    await self._get_approximate_current_memberships_in_room(room_id)
+                )
+                unknown_state_events = {}
+                joined_users_in_current_state = []
+
+                for (type, state_key), event_id in state.items():
+                    if event_id not in current_memberships:
+                        unknown_state_events[type, state_key] = event_id
+                    elif current_memberships[event_id] == Membership.JOIN:
+                        joined_users_in_current_state.append(state_key)
+
                 joined_user_ids = await self.get_joined_user_ids_from_state(
-                    room_id, state
+                    room_id, unknown_state_events
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_user_ids:
+                for user_id in chain(joined_user_ids, joined_users_in_current_state):
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
@@ -1147,6 +1169,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return frozenset(cache.hosts_to_joined_users)
 
+    async def _get_approximate_current_memberships_in_room(
+        self, room_id: str
+    ) -> Mapping[str, Optional[str]]:
+        """Build a map from event id to membership, for all events in the current state.
+
+        The event ids of non-memberships events (e.g. `m.room.power_levels`) are present
+        in the result, mapped to values of `None`.
+
+        The result is approximate for partially-joined rooms. It is fully accurate
+        for fully-joined rooms.
+        """
+
+        rows = await self.db_pool.simple_select_list(
+            "current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=("event_id", "membership"),
+            desc="has_completed_background_updates",
+        )
+        return {row["event_id"]: row["membership"] for row in rows}
+
     @cached(max_entries=10000)
     def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
         return _JoinedHostsCache()