diff --git a/changelog.d/14954.misc b/changelog.d/14954.misc
new file mode 100644
index 0000000000..b86b6bf01e
--- /dev/null
+++ b/changelog.d/14954.misc
@@ -0,0 +1 @@
+Faster room joins: Refactor internal handling of servers in room to never store an empty list.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8493ffc2e5..0ac85a3be7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,6 +19,7 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Awaitable,
Callable,
Collection,
@@ -110,8 +111,9 @@ class SendJoinResult:
# True if 'state' elides non-critical membership events
partial_state: bool
- # if 'partial_state' is set, a list of the servers in the room (otherwise empty)
- servers_in_room: List[str]
+ # If 'partial_state' is set, a set of the servers in the room (otherwise empty).
+ # If 'partial_state' is set, always contains the server we joined off.
+ servers_in_room: AbstractSet[str]
class FederationClient(FederationBase):
@@ -1152,15 +1154,24 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
- if response.members_omitted and not response.servers_in_room:
- raise InvalidResponseError(
- "members_omitted was set, but no servers were listed in the room"
- )
+ servers_in_room = None
+ if response.servers_in_room is not None:
+ servers_in_room = set(response.servers_in_room)
- if response.members_omitted and not partial_state:
- raise InvalidResponseError(
- "members_omitted was set, but we asked for full state"
- )
+ if response.members_omitted:
+ if not servers_in_room:
+ raise InvalidResponseError(
+ "members_omitted was set, but no servers were listed in the room"
+ )
+
+ if not partial_state:
+ raise InvalidResponseError(
+ "members_omitted was set, but we asked for full state"
+ )
+
+ # `servers_in_room` is supposed to be a complete list.
+ # Fix things up in case the remote homeserver is badly behaved.
+ servers_in_room.add(destination)
return SendJoinResult(
event=event,
@@ -1168,7 +1179,7 @@ class FederationClient(FederationBase):
auth_chain=signed_auth,
origin=destination,
partial_state=response.members_omitted,
- servers_in_room=response.servers_in_room or [],
+ servers_in_room=servers_in_room or frozenset(),
)
# MSC3083 defines additional error codes for room joins.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 30ebd62883..43421a9c72 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -447,7 +447,7 @@ class FederationSender(AbstractFederationSender):
)
)
- if len(partial_state_destinations) > 0:
+ if partial_state_destinations is not None:
destinations = partial_state_destinations
if destinations is None:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5c06073901..6f7963df43 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -859,6 +859,7 @@ class DeviceHandler(DeviceWorkerHandler):
known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
room_id
)
+ assert known_hosts_at_join is not None
potentially_changed_hosts.difference_update(known_hosts_at_join)
potentially_changed_hosts.discard(self.server_name)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index dc1cbf5c3d..7f64130e0a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,7 +20,17 @@ import itertools
import logging
from enum import Enum
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ AbstractSet,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
+)
import attr
from prometheus_client import Histogram
@@ -169,7 +179,7 @@ class FederationHandler:
# A dictionary mapping room IDs to (initial destination, other destinations)
# tuples.
self._partial_state_syncs_maybe_needing_restart: Dict[
- str, Tuple[Optional[str], StrCollection]
+ str, Tuple[Optional[str], AbstractSet[str]]
] = {}
# A lock guarding the partial state flag for rooms.
# When the lock is held for a given room, no other concurrent code may
@@ -1720,7 +1730,7 @@ class FederationHandler:
def _start_partial_state_room_sync(
self,
initial_destination: Optional[str],
- other_destinations: StrCollection,
+ other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Starts the background process to resync the state of a partial state room,
@@ -1802,7 +1812,7 @@ class FederationHandler:
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
- other_destinations: StrCollection,
+ other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
@@ -1939,7 +1949,7 @@ class FederationHandler:
def _prioritise_destinations_for_partial_state_resync(
initial_destination: Optional[str],
- other_destinations: StrCollection,
+ other_destinations: AbstractSet[str],
room_id: str,
) -> StrCollection:
"""Work out the order in which we should ask servers to resync events.
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 2045169b9a..52efd4a171 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -569,10 +569,11 @@ class StateStorageController:
is arbitrary for rooms with partial state.
"""
# We have to read this list first to mitigate races with un-partial stating.
- # This will be empty for rooms with full state.
hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
room_id
)
+ if hosts_at_join is None:
+ hosts_at_join = frozenset()
hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4ddb27f686..644bbb8878 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -18,6 +18,7 @@ from abc import abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Awaitable,
Collection,
@@ -25,7 +26,6 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
Set,
Tuple,
Union,
@@ -109,7 +109,7 @@ class RoomSortOrder(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PartialStateResyncInfo:
joined_via: Optional[str]
- servers_in_room: List[str] = attr.ib(factory=list)
+ servers_in_room: Set[str] = attr.ib(factory=set)
class RoomWorkerStore(CacheInvalidationWorkerStore):
@@ -1193,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
- @cached(iterable=True)
- async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
- """Gets the list of servers in a partial state room at the time we joined it.
+ async def get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> Optional[AbstractSet[str]]:
+ """Gets the set of servers in a partial state room at the time we joined it.
Returns:
The `servers_in_room` list from the `/send_join` response for partial state
rooms. May not be accurate or complete, as it comes from a remote
homeserver.
- An empty list for full state rooms.
+ `None` for full state rooms.
"""
- return await self.db_pool.simple_select_onecol(
- "partial_state_rooms_servers",
- keyvalues={"room_id": room_id},
- retcol="server_name",
- desc="get_partial_state_servers_at_join",
+ servers_in_room = await self._get_partial_state_servers_at_join(room_id)
+
+ if len(servers_in_room) == 0:
+ return None
+
+ return servers_in_room
+
+ @cached(iterable=True)
+ async def _get_partial_state_servers_at_join(
+ self, room_id: str
+ ) -> AbstractSet[str]:
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ "partial_state_rooms_servers",
+ keyvalues={"room_id": room_id},
+ retcol="server_name",
+ desc="get_partial_state_servers_at_join",
+ )
)
async def get_partial_state_room_resync_info(
@@ -1252,7 +1266,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# partial-joined between the two SELECTs, but this is unlikely to happen
# in practice.)
continue
- entry.servers_in_room.append(server_name)
+ entry.servers_in_room.add(server_name)
return room_servers
@@ -1942,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
async def store_partial_state_room(
self,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1957,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
room_id: the ID of the room
- servers: other servers known to be in the room
+ servers: other servers known to be in the room. Must include `joined_via`.
device_lists_stream_id: the device_lists stream ID at the time when we first
joined the room.
joined_via: the server name we requested a partial join from.
"""
+ assert joined_via in servers
+
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
@@ -1975,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
txn: LoggingTransaction,
room_id: str,
- servers: Collection[str],
+ servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
@@ -1998,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
async def write_partial_state_rooms_join_event_id(
@@ -2409,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
self._invalidate_cache_and_stream(
- txn, self.get_partial_state_servers_at_join, (room_id,)
+ txn, self._get_partial_state_servers_at_join, (room_id,)
)
DatabasePool.simple_insert_txn(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index c1558c40c3..57675fa407 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -656,7 +656,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
EVENT_INVITATION_MEMBERSHIP,
],
partial_state=True,
- servers_in_room=["example.com"],
+ servers_in_room={"example.com"},
)
)
)
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 6bbfd5dc84..6a38893b68 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -171,7 +171,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
state=[create_event],
auth_chain=[create_event],
partial_state=False,
- servers_in_room=[],
+ servers_in_room=frozenset(),
)
)
)
|