summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/state.py3
-rw-r--r--synapse/storage/databases/main/roommember.py88
2 files changed, 72 insertions, 19 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9ffd0e29e..ba5380ce3e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,7 +23,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
@@ -520,7 +519,7 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
         """Get current hosts in room based on current state."""
 
         await self._partial_state_room_tracker.await_full_state(room_id)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 9e5034b401..06500457bd 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -187,27 +187,48 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
+
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
     def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+                SELECT c.state_key FROM current_state_events as c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
         else:
             sql = """
-                SELECT state_key FROM room_memberships as m
+                SELECT c.state_key FROM room_memberships as m
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 INNER JOIN current_state_events as c
                 ON m.event_id = c.event_id
                 AND m.room_id = c.room_id
                 AND m.user_id = c.state_key
                 WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
 
         txn.execute(sql, (room_id, Membership.JOIN))
@@ -1037,37 +1058,70 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
-        """Get current hosts in room based on current state."""
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+        """
+        Get current hosts in room based on current state.
+
+        The heuristic of sorting by servers who have been in the room the
+        longest is good because they're most likely to have anything we ask
+        about.
+
+        Returns:
+            Returns a list of servers sorted by longest in the room first. (aka.
+            sorted by join with the lowest depth first).
+        """
 
         # First we check if we already have `get_users_in_room` in the cache, as
         # we can just calculate result from that
         users = self.get_users_in_room.cache.get_immediate(
             (room_id,), None, update_metrics=False
         )
-        if users is not None:
-            return {get_domain_from_id(u) for u in users}
-
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if users is None and isinstance(self.database_engine, Sqlite3Engine):
             # If we're using SQLite then let's just always use
             # `get_users_in_room` rather than funky SQL.
             users = await self.get_users_in_room(room_id)
-            return {get_domain_from_id(u) for u in users}
+
+        if users is not None:
+            # Because `users` is sorted from lowest -> highest depth, the list
+            # of domains will also be sorted that way.
+            domains: List[str] = []
+            # We use a `Set` just for fast lookups
+            domain_set: Set[str] = set()
+            for u in users:
+                domain = get_domain_from_id(u)
+                if domain not in domain_set:
+                    domain_set.add(domain)
+                    domains.append(domain)
+            return domains
 
         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.
 
-        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+            # Returns a list of servers currently joined in the room sorted by
+            # longest in the room first (aka. with the lowest depth). The
+            # heuristic of sorting by servers who have been in the room the
+            # longest is good because they're most likely to have anything we
+            # ask about.
             sql = """
-                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
-                FROM current_state_events
+                SELECT
+                    /* Match the domain part of the MXID */
+                    substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+                FROM current_state_events c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND room_id = ?
+                    /* Find any join state events in the room */
+                    c.type = 'm.room.member'
+                    AND c.membership = 'join'
+                    AND c.room_id = ?
+                /* Group all state events from the same domain into their own buckets (groups) */
+                GROUP BY server_domain
+                /* Sorted by lowest depth first */
+                ORDER BY min(e.depth) ASC;
             """
             txn.execute(sql, (room_id,))
-            return {d for d, in txn}
+            return [d for d, in txn]
 
         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room", get_current_hosts_in_room_txn