summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14954.misc1
-rw-r--r--synapse/federation/federation_client.py33
-rw-r--r--synapse/federation/sender/__init__.py2
-rw-r--r--synapse/handlers/device.py1
-rw-r--r--synapse/handlers/federation.py20
-rw-r--r--synapse/storage/controllers/state.py3
-rw-r--r--synapse/storage/databases/main/room.py50
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/handlers/test_room_member.py2
9 files changed, 77 insertions, 37 deletions
diff --git a/changelog.d/14954.misc b/changelog.d/14954.misc
new file mode 100644
index 0000000000..b86b6bf01e
--- /dev/null
+++ b/changelog.d/14954.misc
@@ -0,0 +1 @@
+Faster room joins: Refactor internal handling of servers in room to never store an empty list.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8493ffc2e5..0ac85a3be7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,6 +19,7 @@ import itertools
 import logging
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Awaitable,
     Callable,
     Collection,
@@ -110,8 +111,9 @@ class SendJoinResult:
     # True if 'state' elides non-critical membership events
     partial_state: bool
 
-    # if 'partial_state' is set, a list of the servers in the room (otherwise empty)
-    servers_in_room: List[str]
+    # If 'partial_state' is set, a set of the servers in the room (otherwise empty).
+    # Always contains the server we joined off.
+    servers_in_room: AbstractSet[str]
 
 
 class FederationClient(FederationBase):
@@ -1152,15 +1154,24 @@ class FederationClient(FederationBase):
                     % (auth_chain_create_events,)
                 )
 
-            if response.members_omitted and not response.servers_in_room:
-                raise InvalidResponseError(
-                    "members_omitted was set, but no servers were listed in the room"
-                )
+            servers_in_room = None
+            if response.servers_in_room is not None:
+                servers_in_room = set(response.servers_in_room)
 
-            if response.members_omitted and not partial_state:
-                raise InvalidResponseError(
-                    "members_omitted was set, but we asked for full state"
-                )
+            if response.members_omitted:
+                if not servers_in_room:
+                    raise InvalidResponseError(
+                        "members_omitted was set, but no servers were listed in the room"
+                    )
+
+                if not partial_state:
+                    raise InvalidResponseError(
+                        "members_omitted was set, but we asked for full state"
+                    )
+
+                # `servers_in_room` is supposed to be a complete list.
+                # Fix things up in case the remote homeserver is badly behaved.
+                servers_in_room.add(destination)
 
             return SendJoinResult(
                 event=event,
@@ -1168,7 +1179,7 @@ class FederationClient(FederationBase):
                 auth_chain=signed_auth,
                 origin=destination,
                 partial_state=response.members_omitted,
-                servers_in_room=response.servers_in_room or [],
+                servers_in_room=servers_in_room or frozenset(),
             )
 
         # MSC3083 defines additional error codes for room joins.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 30ebd62883..43421a9c72 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -447,7 +447,7 @@ class FederationSender(AbstractFederationSender):
                             )
                         )
 
-                        if len(partial_state_destinations) > 0:
+                        if partial_state_destinations is not None:
                             destinations = partial_state_destinations
 
                     if destinations is None:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5c06073901..6f7963df43 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -859,6 +859,7 @@ class DeviceHandler(DeviceWorkerHandler):
         known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
             room_id
         )
+        assert known_hosts_at_join is not None
         potentially_changed_hosts.difference_update(known_hosts_at_join)
 
         potentially_changed_hosts.discard(self.server_name)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index dc1cbf5c3d..7f64130e0a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,7 +20,17 @@ import itertools
 import logging
 from enum import Enum
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    AbstractSet,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import attr
 from prometheus_client import Histogram
@@ -169,7 +179,7 @@ class FederationHandler:
         # A dictionary mapping room IDs to (initial destination, other destinations)
         # tuples.
         self._partial_state_syncs_maybe_needing_restart: Dict[
-            str, Tuple[Optional[str], StrCollection]
+            str, Tuple[Optional[str], AbstractSet[str]]
         ] = {}
         # A lock guarding the partial state flag for rooms.
         # When the lock is held for a given room, no other concurrent code may
@@ -1720,7 +1730,7 @@ class FederationHandler:
     def _start_partial_state_room_sync(
         self,
         initial_destination: Optional[str],
-        other_destinations: StrCollection,
+        other_destinations: AbstractSet[str],
         room_id: str,
     ) -> None:
         """Starts the background process to resync the state of a partial state room,
@@ -1802,7 +1812,7 @@ class FederationHandler:
     async def _sync_partial_state_room(
         self,
         initial_destination: Optional[str],
-        other_destinations: StrCollection,
+        other_destinations: AbstractSet[str],
         room_id: str,
     ) -> None:
         """Background process to resync the state of a partial-state room
@@ -1939,7 +1949,7 @@ class FederationHandler:
 
 def _prioritise_destinations_for_partial_state_resync(
     initial_destination: Optional[str],
-    other_destinations: StrCollection,
+    other_destinations: AbstractSet[str],
     room_id: str,
 ) -> StrCollection:
     """Work out the order in which we should ask servers to resync events.
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 2045169b9a..52efd4a171 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -569,10 +569,11 @@ class StateStorageController:
         is arbitrary for rooms with partial state.
         """
         # We have to read this list first to mitigate races with un-partial stating.
-        # This will be empty for rooms with full state.
         hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
             room_id
         )
+        if hosts_at_join is None:
+            hosts_at_join = frozenset()
 
         hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4ddb27f686..644bbb8878 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -18,6 +18,7 @@ from abc import abstractmethod
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Awaitable,
     Collection,
@@ -25,7 +26,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Union,
@@ -109,7 +109,7 @@ class RoomSortOrder(Enum):
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class PartialStateResyncInfo:
     joined_via: Optional[str]
-    servers_in_room: List[str] = attr.ib(factory=list)
+    servers_in_room: Set[str] = attr.ib(factory=set)
 
 
 class RoomWorkerStore(CacheInvalidationWorkerStore):
@@ -1193,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_rooms_for_retention_period_in_range_txn,
         )
 
-    @cached(iterable=True)
-    async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
-        """Gets the list of servers in a partial state room at the time we joined it.
+    async def get_partial_state_servers_at_join(
+        self, room_id: str
+    ) -> Optional[AbstractSet[str]]:
+        """Gets the set of servers in a partial state room at the time we joined it.
 
         Returns:
             The `servers_in_room` list from the `/send_join` response for partial state
             rooms. May not be accurate or complete, as it comes from a remote
             homeserver.
-            An empty list for full state rooms.
+            `None` for full state rooms.
         """
-        return await self.db_pool.simple_select_onecol(
-            "partial_state_rooms_servers",
-            keyvalues={"room_id": room_id},
-            retcol="server_name",
-            desc="get_partial_state_servers_at_join",
+        servers_in_room = await self._get_partial_state_servers_at_join(room_id)
+
+        if len(servers_in_room) == 0:
+            return None
+
+        return servers_in_room
+
+    @cached(iterable=True)
+    async def _get_partial_state_servers_at_join(
+        self, room_id: str
+    ) -> AbstractSet[str]:
+        return frozenset(
+            await self.db_pool.simple_select_onecol(
+                "partial_state_rooms_servers",
+                keyvalues={"room_id": room_id},
+                retcol="server_name",
+                desc="get_partial_state_servers_at_join",
+            )
         )
 
     async def get_partial_state_room_resync_info(
@@ -1252,7 +1266,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 # partial-joined between the two SELECTs, but this is unlikely to happen
                 # in practice.)
                 continue
-            entry.servers_in_room.append(server_name)
+            entry.servers_in_room.add(server_name)
 
         return room_servers
 
@@ -1942,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     async def store_partial_state_room(
         self,
         room_id: str,
-        servers: Collection[str],
+        servers: AbstractSet[str],
         device_lists_stream_id: int,
         joined_via: str,
     ) -> None:
@@ -1957,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
         Args:
             room_id: the ID of the room
-            servers: other servers known to be in the room
+            servers: other servers known to be in the room. must include `joined_via`.
             device_lists_stream_id: the device_lists stream ID at the time when we first
                 joined the room.
             joined_via: the server name we requested a partial join from.
         """
+        assert joined_via in servers
+
         await self.db_pool.runInteraction(
             "store_partial_state_room",
             self._store_partial_state_room_txn,
@@ -1975,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         self,
         txn: LoggingTransaction,
         room_id: str,
-        servers: Collection[str],
+        servers: AbstractSet[str],
         device_lists_stream_id: int,
         joined_via: str,
     ) -> None:
@@ -1998,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         )
         self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
         self._invalidate_cache_and_stream(
-            txn, self.get_partial_state_servers_at_join, (room_id,)
+            txn, self._get_partial_state_servers_at_join, (room_id,)
         )
 
     async def write_partial_state_rooms_join_event_id(
@@ -2409,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         )
         self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
         self._invalidate_cache_and_stream(
-            txn, self.get_partial_state_servers_at_join, (room_id,)
+            txn, self._get_partial_state_servers_at_join, (room_id,)
         )
 
         DatabasePool.simple_insert_txn(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index c1558c40c3..57675fa407 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -656,7 +656,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
                         EVENT_INVITATION_MEMBERSHIP,
                     ],
                     partial_state=True,
-                    servers_in_room=["example.com"],
+                    servers_in_room={"example.com"},
                 )
             )
         )
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 6bbfd5dc84..6a38893b68 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -171,7 +171,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
                     state=[create_event],
                     auth_chain=[create_event],
                     partial_state=False,
-                    servers_in_room=[],
+                    servers_in_room=frozenset(),
                 )
             )
         )