summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/state.py3
-rw-r--r--synapse/storage/databases/main/room.py50
2 files changed, 35 insertions, 18 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 2045169b9a..52efd4a171 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -569,10 +569,11 @@ class StateStorageController:
         is arbitrary for rooms with partial state.
         """
         # We have to read this list first to mitigate races with un-partial stating.
-        # This will be empty for rooms with full state.
         hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
             room_id
         )
+        if hosts_at_join is None:
+            hosts_at_join = frozenset()
 
         hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 4ddb27f686..644bbb8878 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -18,6 +18,7 @@ from abc import abstractmethod
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Awaitable,
     Collection,
@@ -25,7 +26,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Union,
@@ -109,7 +109,7 @@ class RoomSortOrder(Enum):
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class PartialStateResyncInfo:
     joined_via: Optional[str]
-    servers_in_room: List[str] = attr.ib(factory=list)
+    servers_in_room: Set[str] = attr.ib(factory=set)
 
 
 class RoomWorkerStore(CacheInvalidationWorkerStore):
@@ -1193,21 +1193,35 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_rooms_for_retention_period_in_range_txn,
         )
 
-    @cached(iterable=True)
-    async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
-        """Gets the list of servers in a partial state room at the time we joined it.
+    async def get_partial_state_servers_at_join(
+        self, room_id: str
+    ) -> Optional[AbstractSet[str]]:
+        """Gets the set of servers in a partial state room at the time we joined it.
 
         Returns:
             The `servers_in_room` list from the `/send_join` response for partial state
             rooms. May not be accurate or complete, as it comes from a remote
             homeserver.
-            An empty list for full state rooms.
+            `None` for full state rooms.
         """
-        return await self.db_pool.simple_select_onecol(
-            "partial_state_rooms_servers",
-            keyvalues={"room_id": room_id},
-            retcol="server_name",
-            desc="get_partial_state_servers_at_join",
+        servers_in_room = await self._get_partial_state_servers_at_join(room_id)
+
+        if len(servers_in_room) == 0:
+            return None
+
+        return servers_in_room
+
+    @cached(iterable=True)
+    async def _get_partial_state_servers_at_join(
+        self, room_id: str
+    ) -> AbstractSet[str]:
+        return frozenset(
+            await self.db_pool.simple_select_onecol(
+                "partial_state_rooms_servers",
+                keyvalues={"room_id": room_id},
+                retcol="server_name",
+                desc="get_partial_state_servers_at_join",
+            )
         )
 
     async def get_partial_state_room_resync_info(
@@ -1252,7 +1266,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 # partial-joined between the two SELECTs, but this is unlikely to happen
                 # in practice.)
                 continue
-            entry.servers_in_room.append(server_name)
+            entry.servers_in_room.add(server_name)
 
         return room_servers
 
@@ -1942,7 +1956,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     async def store_partial_state_room(
         self,
         room_id: str,
-        servers: Collection[str],
+        servers: AbstractSet[str],
         device_lists_stream_id: int,
         joined_via: str,
     ) -> None:
@@ -1957,11 +1971,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
         Args:
             room_id: the ID of the room
-            servers: other servers known to be in the room
+            servers: other servers known to be in the room. must include `joined_via`.
             device_lists_stream_id: the device_lists stream ID at the time when we first
                 joined the room.
             joined_via: the server name we requested a partial join from.
         """
+        assert joined_via in servers
+
         await self.db_pool.runInteraction(
             "store_partial_state_room",
             self._store_partial_state_room_txn,
@@ -1975,7 +1991,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         self,
         txn: LoggingTransaction,
         room_id: str,
-        servers: Collection[str],
+        servers: AbstractSet[str],
         device_lists_stream_id: int,
         joined_via: str,
     ) -> None:
@@ -1998,7 +2014,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         )
         self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
         self._invalidate_cache_and_stream(
-            txn, self.get_partial_state_servers_at_join, (room_id,)
+            txn, self._get_partial_state_servers_at_join, (room_id,)
         )
 
     async def write_partial_state_rooms_join_event_id(
@@ -2409,7 +2425,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         )
         self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
         self._invalidate_cache_and_stream(
-            txn, self.get_partial_state_servers_at_join, (room_id,)
+            txn, self._get_partial_state_servers_at_join, (room_id,)
         )
 
         DatabasePool.simple_insert_txn(