diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0f1f0d11ea..52ad947c6c 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -25,6 +25,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Tuple,
Union,
cast,
@@ -96,6 +97,12 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PartialStateResyncInfo:
+ """Per-room information needed to restart the resync of a partial state room."""
+
+ # The value of `partial_state_rooms.joined_via` for the room; may be None.
+ joined_via: Optional[str]
+ # Server names from `partial_state_rooms_servers`. The class is frozen, but the
+ # list itself stays mutable and is appended to after construction.
+ servers_in_room: List[str] = attr.ib(factory=list)
+
+
class RoomWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -206,21 +213,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
- ) -> Tuple[Union[str, None], List[str]]:
+ ) -> Tuple[Union[str, None], list]:
+ """Build the WHERE-clause fragment that filters rooms by room type.
+
+ Args:
+ room_types: the room types to match. A None entry matches rooms
+ without a type.
+
+ Returns:
+ A tuple of the parenthesised SQL clause and its query arguments, or
+ `(None, [])` if `room_types` is empty or None.
+ """
if not room_types:
return None, []
- else:
- # We use None when we want get rooms without a type
- is_null_clause = ""
- if None in room_types:
- is_null_clause = "OR room_type IS NULL"
- room_types = [value for value in room_types if value is not None]
+ # Since None is used to represent a room without a type, care needs to
+ # be taken when constructing the where clause.
+ clauses = []
+ args: list = []
+
+ room_types_set = set(room_types)
+
+ # We use None to represent a room without a type.
+ if None in room_types_set:
+ clauses.append("room_type IS NULL")
+ room_types_set.remove(None)
+
+ # If there are other room types, generate the proper clause. Check the
+ # set, not `room_types`: the latter is still truthy when it only held None.
+ if room_types_set:
list_clause, args = make_in_list_sql_clause(
- self.database_engine, "room_type", room_types
+ self.database_engine, "room_type", room_types_set
)
+ clauses.append(list_clause)
- return f"({list_clause} {is_null_clause})", args
+ return f"({' OR '.join(clauses)})", args
async def count_public_rooms(
self,
@@ -240,14 +256,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
- room_type_clause, args = self._construct_room_type_where_clause(
- search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
- if search_filter
- else None
- )
- room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
- query_args += args
-
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -267,6 +275,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list
"""
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
sql = f"""
SELECT
COUNT(*)
@@ -641,8 +657,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"version": room[5],
"creator": room[6],
"encryption": room[7],
- "federatable": room[8],
- "public": room[9],
+ # room_stats_state.federatable is an integer on sqlite.
+ "federatable": bool(room[8]),
+ # rooms.is_public is an integer on sqlite.
+ "public": bool(room[9]),
"join_rules": room[10],
"guest_access": room[11],
"history_visibility": room[12],
@@ -894,7 +912,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
event_json = db_to_json(content_json)
content = event_json["content"]
content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
+ info = content.get("info")
+ if isinstance(info, dict):
+ thumbnail_url = info.get("thumbnail_url")
+ else:
+ thumbnail_url = None
for url in (content_url, thumbnail_url):
if not url:
@@ -1131,17 +1153,46 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
get_rooms_for_retention_period_in_range_txn,
)
- async def get_partial_state_rooms_and_servers(
+ @cached(iterable=True)
+ async def get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
+ """Look up the servers that were in a partial state room when we joined it.
+
+ Args:
+ room_id: the ID of the room to look up.
+
+ Returns:
+ For a partial state room, the `servers_in_room` list taken from the
+ `/send_join` response. It was supplied by a remote homeserver, so it may
+ be inaccurate or incomplete.
+ For a full state room, an empty list.
+ """
+ room_filter = {"room_id": room_id}
+ servers_at_join = await self.db_pool.simple_select_onecol(
+ "partial_state_rooms_servers",
+ keyvalues=room_filter,
+ retcol="server_name",
+ desc="get_partial_state_servers_at_join",
+ )
+ return servers_at_join
+
+ async def get_partial_state_room_resync_info(
self,
- ) -> Mapping[str, Collection[str]]:
- """Get all rooms containing events with partial state, and the servers known
- to be in the room.
+ ) -> Mapping[str, PartialStateResyncInfo]:
+ """Get all rooms containing events with partial state, and the information
+ needed to restart a "resync" of those rooms.
Returns:
A dictionary of rooms with partial state, with room IDs as keys and
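- lists of servers in rooms as values.
+ `PartialStateResyncInfo` objects (the server we joined via and the
+ servers in the room) as values.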
"""
- room_servers: Dict[str, List[str]] = {}
+ room_servers: Dict[str, PartialStateResyncInfo] = {}
+
+ rows = await self.db_pool.simple_select_list(
+ table="partial_state_rooms",
+ keyvalues={},
+ retcols=("room_id", "joined_via"),
+ desc="get_server_which_served_partial_join",
+ )
+
+ for row in rows:
+ room_id = row["room_id"]
+ joined_via = row["joined_via"]
+ room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via)
rows = await self.db_pool.simple_select_list(
"partial_state_rooms_servers",
@@ -1153,7 +1204,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
for row in rows:
room_id = row["room_id"]
server_name = row["server_name"]
- room_servers.setdefault(room_id, []).append(server_name)
+ entry = room_servers.get(room_id)
+ if entry is None:
+ # There is a foreign key constraint which enforces that every room_id in
+ # partial_state_rooms_servers appears in partial_state_rooms. So we
+ # expect `entry` to be non-null. (This reasoning fails if we've
+ # partial-joined between the two SELECTs, but this is unlikely to happen
+ # in practice.)
+ continue
+ entry.servers_in_room.append(server_name)
return room_servers
@@ -1183,8 +1242,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
return False
- @staticmethod
- def _clear_partial_state_room_txn(txn: LoggingTransaction, room_id: str) -> None:
+ def _clear_partial_state_room_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> None:
DatabasePool.simple_delete_txn(
txn,
table="partial_state_rooms_servers",
@@ -1195,7 +1255,32 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
table="partial_state_rooms",
keyvalues={"room_id": room_id},
)
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_partial_state_servers_at_join, (room_id,)
+ )
+ # We now delete anything from `device_lists_remote_pending` with a
+ # stream ID less than or equal to the minimum
+ # `partial_state_rooms.device_lists_stream_id`, as we no longer need them.
+ device_lists_stream_id = DatabasePool.simple_select_one_onecol_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={},
+ retcol="MIN(device_lists_stream_id)",
+ allow_none=True,
+ )
+ if device_lists_stream_id is None:
+ # There are no rooms being currently partially joined, so we delete everything.
+ txn.execute("DELETE FROM device_lists_remote_pending")
+ else:
+ sql = """
+ DELETE FROM device_lists_remote_pending
+ WHERE stream_id <= ?
+ """
+ txn.execute(sql, (device_lists_stream_id,))
+
+ @cached()
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
@@ -1214,6 +1299,22 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return entry is not None
+ async def get_join_event_id_and_device_lists_stream_id_for_partial_state(
+ self, room_id: str
+ ) -> Tuple[str, int]:
+ """Fetch the join event ID and device list stream ID stored for a partial join.
+
+ Args:
+ room_id: the ID of the partial state room.
+
+ Returns:
+ A tuple of the `join_event_id` and `device_lists_stream_id` columns of
+ the room's `partial_state_rooms` row, i.e. the event ID of the join that
+ started the partial join and the device list stream ID at that point.
+ """
+ # NOTE(review): `_store_partial_state_room_txn` inserts `join_event_id` as
+ # None and it is only filled in later, so the first element may be None
+ # despite the annotation -- confirm against callers.
+ wanted_columns = ("join_event_id", "device_lists_stream_id")
+ row = await self.db_pool.simple_select_one(
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ retcols=wanted_columns,
+ desc="get_join_event_id_for_partial_state",
+ )
+ join_event_id = row["join_event_id"]
+ device_lists_stream_id = row["device_lists_stream_id"]
+ return join_event_id, device_lists_stream_id
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1755,29 +1856,51 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
room_id: str,
servers: Collection[str],
+ device_lists_stream_id: int,
+ joined_via: str,
) -> None:
- """Mark the given room as containing events with partial state
+ """Mark the given room as containing events with partial state.
+
+ We also store additional data that describes _when_ we first partial-joined this
+ room, which helps us to keep other homeservers in sync when we finally fully
+ join this room.
+
+ We do not include a `join_event_id` here---we need to wait for the join event
+ to be persisted first.
Args:
room_id: the ID of the room
servers: other servers known to be in the room
+ device_lists_stream_id: the device_lists stream ID at the time when we first
+ joined the room.
+ joined_via: the server name we requested a partial join from.
"""
await self.db_pool.runInteraction(
"store_partial_state_room",
self._store_partial_state_room_txn,
room_id,
servers,
+ device_lists_stream_id,
+ joined_via,
)
- @staticmethod
def _store_partial_state_room_txn(
- txn: LoggingTransaction, room_id: str, servers: Collection[str]
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ servers: Collection[str],
+ device_lists_stream_id: int,
+ joined_via: str,
) -> None:
DatabasePool.simple_insert_txn(
txn,
table="partial_state_rooms",
values={
"room_id": room_id,
+ "device_lists_stream_id": device_lists_stream_id,
+ # To be updated later once the join event is persisted.
+ "join_event_id": None,
+ "joined_via": joined_via,
},
)
DatabasePool.simple_insert_many_txn(
@@ -1786,6 +1909,40 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
keys=("room_id", "server_name"),
values=((room_id, s) for s in servers),
)
+ self._invalidate_cache_and_stream(txn, self.is_partial_state_room, (room_id,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_partial_state_servers_at_join, (room_id,)
+ )
+
+ async def write_partial_state_rooms_join_event_id(
+ self,
+ room_id: str,
+ join_event_id: str,
+ ) -> None:
+ """Record the join event which resulted from a partial join.
+
+ We do this separately to `store_partial_state_room` because we need to wait for
+ the join event to be persisted. Otherwise we violate a foreign key constraint.
+
+ Args:
+ room_id: the ID of the partial state room.
+ join_event_id: the event ID of the join event, to be stored in the
+ room's `partial_state_rooms` row.
+ """
+ await self.db_pool.runInteraction(
+ "write_partial_state_rooms_join_event_id",
+ self._write_partial_state_rooms_join_event_id,
+ room_id,
+ join_event_id,
+ )
+
+ def _write_partial_state_rooms_join_event_id(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ join_event_id: str,
+ ) -> None:
+ """Set `join_event_id` on the `partial_state_rooms` row for `room_id`.
+
+ Args:
+ txn: the database transaction to run the update in.
+ room_id: the ID of the partial state room.
+ join_event_id: the event ID to store.
+ """
+ DatabasePool.simple_update_txn(
+ txn,
+ table="partial_state_rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"join_event_id": join_event_id},
+ )
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
@@ -1904,7 +2061,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
report_id: ID of reported event in database
Returns:
- event_report: json list of information from event report
+ JSON dict of information from an event report or None if the
+ report does not exist.
"""
def _get_event_report_txn(
@@ -1977,8 +2135,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
Returns:
- event_reports: json list of event reports
- count: total number of event reports matching the filter criteria
+ Tuple of:
+ json list of event reports
+ total number of event reports matching the filter criteria
"""
def _get_event_reports_paginate_txn(
@@ -2001,9 +2160,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
""".format(
where_clause
|