summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py53
-rw-r--r--synapse/handlers/room.py14
-rw-r--r--synapse/storage/controllers/state.py3
-rw-r--r--synapse/storage/databases/main/roommember.py88
4 files changed, 88 insertions, 70 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e151962055..dd4b9f66d1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -70,7 +70,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -104,37 +104,6 @@ backfill_processing_before_timer = Histogram(
 )
 
 
-def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-    """Get joined domains from state
-
-    Args:
-        state: State map from type/state key to event.
-
-    Returns:
-        Returns a list of servers with the lowest depth of their joins.
-            Sorted by lowest depth first.
-    """
-    joined_users = [
-        (state_key, int(event.depth))
-        for (e_type, state_key), event in state.items()
-        if e_type == EventTypes.Member and event.membership == Membership.JOIN
-    ]
-
-    joined_domains: Dict[str, int] = {}
-    for u, d in joined_users:
-        try:
-            dom = get_domain_from_id(u)
-            old_d = joined_domains.get(dom)
-            if old_d:
-                joined_domains[dom] = min(d, old_d)
-            else:
-                joined_domains[dom] = d
-        except Exception:
-            pass
-
-    return sorted(joined_domains.items(), key=lambda d: d[1])
-
-
 class _BackfillPointType(Enum):
     # a regular backwards extremity (ie, an event which we don't yet have, but which
     # is referred to by other events in the DAG)
@@ -432,21 +401,19 @@ class FederationHandler:
         )
 
         # Now we need to decide which hosts to hit first.
-
-        # First we try hosts that are already in the room
+        # First we try hosts that are already in the room.
         # TODO: HEURISTIC ALERT.
+        likely_domains = (
+            await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+        )
 
-        curr_state = await self._storage_controllers.state.get_current_state(room_id)
-
-        curr_domains = get_domains_from_state(curr_state)
-
-        likely_domains = [
-            domain for domain, depth in curr_domains if domain != self.server_name
-        ]
-
-        async def try_backfill(domains: List[str]) -> bool:
+        async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
+                # We don't want to ask our own server for information we don't have
+                if dom == self.server_name:
+                    continue
+
                 try:
                     await self._federation_event_handler.backfill(
                         dom, room_id, limit=100, extremities=extremities_to_request
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2fc8264858..f64a8690a5 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,7 +60,6 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers.federation import get_domains_from_state
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
@@ -1462,17 +1461,16 @@ class TimestampLookupHandler:
                 timestamp,
             )
 
-            # Find other homeservers from the given state in the room
-            curr_state = await self._storage_controllers.state.get_current_state(
-                room_id
+            likely_domains = (
+                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
             )
-            curr_domains = get_domains_from_state(curr_state)
-            likely_domains = [
-                domain for domain, depth in curr_domains if domain != self.server_name
-            ]
 
             # Loop through each homeserver candidate until we get a succesful response
             for domain in likely_domains:
+                # We don't want to ask our own server for information we don't have
+                if domain == self.server_name:
+                    continue
+
                 try:
                     remote_response = await self.federation_client.timestamp_to_event(
                         domain, room_id, timestamp, direction
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9ffd0e29e..ba5380ce3e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,7 +23,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
@@ -520,7 +519,7 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
         """Get current hosts in room based on current state."""
 
         await self._partial_state_room_tracker.await_full_state(room_id)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 9e5034b401..06500457bd 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -187,27 +187,48 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
+
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
     def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
+        """
+        Returns a list of users in the room sorted by longest in the room first
+        (aka. with the lowest depth). This is done to match the sort in
+        `get_current_hosts_in_room()` and so we can re-use the cache but it's
+        not horrible to have here either.
+        """
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+                SELECT c.state_key FROM current_state_events as c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
         else:
             sql = """
-                SELECT state_key FROM room_memberships as m
+                SELECT c.state_key FROM room_memberships as m
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 INNER JOIN current_state_events as c
                 ON m.event_id = c.event_id
                 AND m.room_id = c.room_id
                 AND m.user_id = c.state_key
                 WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+                /* Sorted by lowest depth first */
+                ORDER BY e.depth ASC;
             """
 
         txn.execute(sql, (room_id, Membership.JOIN))
@@ -1037,37 +1058,70 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
-        """Get current hosts in room based on current state."""
+    async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
+        """
+        Get current hosts in room based on current state.
+
+        The heuristic of sorting by servers who have been in the room the
+        longest is good because they're most likely to have anything we ask
+        about.
+
+        Returns:
+            Returns a list of servers sorted by longest in the room first. (aka.
+            sorted by join with the lowest depth first).
+        """
 
         # First we check if we already have `get_users_in_room` in the cache, as
         # we can just calculate result from that
         users = self.get_users_in_room.cache.get_immediate(
             (room_id,), None, update_metrics=False
         )
-        if users is not None:
-            return {get_domain_from_id(u) for u in users}
-
-        if isinstance(self.database_engine, Sqlite3Engine):
+        if users is None and isinstance(self.database_engine, Sqlite3Engine):
             # If we're using SQLite then let's just always use
             # `get_users_in_room` rather than funky SQL.
             users = await self.get_users_in_room(room_id)
-            return {get_domain_from_id(u) for u in users}
+
+        if users is not None:
+            # Because `users` is sorted from lowest -> highest depth, the list
+            # of domains will also be sorted that way.
+            domains: List[str] = []
+            # We use a `Set` just for fast lookups
+            domain_set: Set[str] = set()
+            for u in users:
+                domain = get_domain_from_id(u)
+                if domain not in domain_set:
+                    domain_set.add(domain)
+                    domains.append(domain)
+            return domains
 
         # For PostgreSQL we can use a regex to pull out the domains from the
         # joined users in `current_state_events` via regex.
 
-        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
+        def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> List[str]:
+            # Returns a list of servers currently joined in the room sorted by
+            # longest in the room first (aka. with the lowest depth). The
+            # heuristic of sorting by servers who have been in the room the
+            # longest is good because they're most likely to have anything we
+            # ask about.
             sql = """
-                SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
-                FROM current_state_events
+                SELECT
+                    /* Match the domain part of the MXID */
+                    substring(c.state_key FROM '@[^:]*:(.*)$') as server_domain
+                FROM current_state_events c
+                /* Get the depth of the event from the events table */
+                INNER JOIN events AS e USING (event_id)
                 WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND room_id = ?
+                    /* Find any join state events in the room */
+                    c.type = 'm.room.member'
+                    AND c.membership = 'join'
+                    AND c.room_id = ?
+                /* Group all state events from the same domain into their own buckets (groups) */
+                GROUP BY server_domain
+                /* Sorted by lowest depth first */
+                ORDER BY min(e.depth) ASC;
             """
             txn.execute(sql, (room_id,))
-            return {d for d, in txn}
+            return [d for d, in txn]
 
         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room", get_current_hosts_in_room_txn