summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py6
-rw-r--r--synapse/storage/controllers/state.py44
-rw-r--r--synapse/storage/databases/main/room.py17
3 files changed, 65 insertions, 2 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 901e2310b7..6566b3bf3d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -688,11 +688,15 @@ class DeviceHandler(DeviceWorkerHandler):
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
                         hosts = set(
-                            await self._storage_controllers.state.get_current_hosts_in_room(
+                            await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
                                 room_id
                             )
                         )
                         hosts.discard(self.server_name)
+                        # For rooms with partial state, `hosts` is merely an
+                        # approximation. When we transition to a full state room, we
+                        # will have to send out device list updates to any servers we
+                        # missed.
 
                     # Check if we've already sent this update to some hosts
                     if current_stream_id == stream_id:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index bbe568bf05..b1aa17047c 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -23,6 +23,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
 )
 
@@ -524,12 +525,53 @@ class StateStorageController:
         return state_map.get(key)
 
     async def get_current_hosts_in_room(self, room_id: str) -> List[str]:
-        """Get current hosts in room based on current state."""
+        """Get current hosts in room based on current state.
+
+        Blocks until we have full state for the given room. This only happens for rooms
+        with partial state.
+
+        Returns:
+            A list of hosts in the room, sorted by longest in the room first. (aka.
+            sorted by join with the lowest depth first).
+        """
 
         await self._partial_state_room_tracker.await_full_state(room_id)
 
         return await self.stores.main.get_current_hosts_in_room(room_id)
 
+    async def get_current_hosts_in_room_or_partial_state_approximation(
+        self, room_id: str
+    ) -> Sequence[str]:
+        """Get approximation of current hosts in room based on current state.
+
+        For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
+        with the same order of results.
+
+        For rooms with partial state, no blocking occurs. Instead, the list of hosts
+        in the room at the time of joining is combined with the list of hosts which
+        joined the room afterwards. The returned list may include hosts that are not
+        actually in the room and exclude hosts that are in the room, since we may
+        calculate state incorrectly during the partial state phase. The order of results
+        is arbitrary for rooms with partial state.
+        """
+        # We have to read this list first to mitigate races with un-partial stating.
+        # This will be empty for rooms with full state.
+        hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
+            room_id
+        )
+
+        hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
+        hosts_from_state_set = set(hosts_from_state)
+
+        # First take the list of hosts based on the current state.
+        # For rooms with partial state, this will be missing most hosts.
+        hosts = list(hosts_from_state)
+        # Then add in the list of hosts in the room at the time we joined.
+        # This will be an empty list for rooms with full state.
+        hosts.extend(host for host in hosts_at_join if host not in hosts_from_state_set)
+
+        return hosts
+
     async def get_users_in_room_with_profiles(
         self, room_id: str
     ) -> Dict[str, ProfileInfo]:
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index bef66f1992..5dd116d766 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -25,6 +25,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -1133,6 +1134,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_rooms_for_retention_period_in_range_txn,
         )
 
+    async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
+        """Gets the list of servers in a partial state room at the time we joined it.
+
+        Returns:
+            The `servers_in_room` list from the `/send_join` response for partial state
+            rooms. May not be accurate or complete, as it comes from a remote
+            homeserver.
+            An empty list for full state rooms.
+        """
+        return await self.db_pool.simple_select_onecol(
+            "partial_state_rooms_servers",
+            keyvalues={"room_id": room_id},
+            retcol="server_name",
+            desc="get_partial_state_servers_at_join",
+        )
+
     async def get_partial_state_rooms_and_servers(
         self,
     ) -> Mapping[str, Collection[str]]: