diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfca34550d..d7f015c783 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
@@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
desc="mark_remote_user_device_cache_as_valid",
)
+ async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ await self.db_pool.runInteraction(
+ "_handle_potentially_left_users",
+ self.handle_potentially_left_users_txn,
+ user_ids,
+ )
+
+ def handle_potentially_left_users_txn(
+ self,
+ txn: LoggingTransaction,
+ user_ids: Set[str],
+ ) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+ left_users = user_ids - joined_users
+
+ for user_id in left_users:
+ self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user."""
- def _mark_remote_user_device_list_as_unsubscribed_txn(
- txn: LoggingTransaction,
- ) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
- )
-
await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
- _mark_remote_user_device_list_as_unsubscribed_txn,
+ self.mark_remote_user_device_list_as_unsubscribed_txn,
+ user_id,
+ )
+
+ def mark_remote_user_device_list_as_unsubscribed_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ ) -> None:
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
async def get_dehydrated_device(
@@ -1271,6 +1307,33 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
return changes
+ async def get_device_list_changes_in_room(
+ self, room_id: str, min_stream_id: int
+ ) -> Collection[Tuple[str, str]]:
+ """Get all device list changes that happened in the room since the given
+ stream ID.
+
+ Returns:
+ Collection of user ID/device ID tuples of all devices that have
+ changed
+ """
+
+ sql = """
+ SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
+ WHERE room_id = ? AND stream_id > ?
+ """
+
+ def get_device_list_changes_in_room_txn(
+ txn: LoggingTransaction,
+ ) -> Collection[Tuple[str, str]]:
+ txn.execute(sql, (room_id, min_stream_id))
+ return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction(
+ "get_device_list_changes_in_room",
+ get_device_list_changes_in_room_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -1910,14 +1973,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
room_id: str,
- stream_id: int,
+ stream_id: Optional[int],
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
- Marks the associated row in `device_lists_changes_in_room` as handled.
+ Marks the associated row in `device_lists_changes_in_room` as handled,
+ if `stream_id` is provided.
"""
def add_device_list_outbound_pokes_txn(
@@ -1933,17 +1997,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context=context,
)
- self.db_pool.simple_update_txn(
- txn,
- table="device_lists_changes_in_room",
- keyvalues={
- "user_id": user_id,
- "device_id": device_id,
- "stream_id": stream_id,
- "room_id": room_id,
- },
- updatevalues={"converted_to_destinations": True},
- )
+ if stream_id:
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
@@ -1959,3 +2024,58 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
add_device_list_outbound_pokes_txn,
stream_ids,
)
+
+ async def add_remote_device_list_to_pending(
+ self, user_id: str, device_id: str
+ ) -> None:
+ """Add a device list update to the table tracking remote device list
+ updates during partial joins.
+ """
+
+ async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
+ await self.db_pool.simple_upsert(
+ table="device_lists_remote_pending",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ },
+ values={"stream_id": stream_id},
+ desc="add_remote_device_list_to_pending",
+ )
+
+ async def get_pending_remote_device_list_updates_for_room(
+ self, room_id: str
+ ) -> Collection[Tuple[str, str]]:
+ """Get the set of remote device list updates from the pending table for
+ the room.
+ """
+
+ min_device_stream_id = await self.db_pool.simple_select_one_onecol(
+ table="partial_state_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="device_lists_stream_id",
+ desc="get_pending_remote_device_list_updates_for_room_device",
+ )
+
+ sql = """
+ SELECT user_id, device_id FROM device_lists_remote_pending AS d
+ INNER JOIN current_state_events AS c ON
+ type = 'm.room.member'
+ AND state_key = user_id
+ AND membership = 'join'
+ WHERE
+ room_id = ? AND stream_id > ?
+ """
+
+ def get_pending_remote_device_list_updates_for_room_txn(
+ txn: LoggingTransaction,
+ ) -> Collection[Tuple[str, str]]:
+ txn.execute(sql, (room_id, min_device_stream_id))
+ return cast(Collection[Tuple[str, str]], txn.fetchall())
+
+ return await self.db_pool.runInteraction(
+ "get_pending_remote_device_list_updates_for_room",
+ get_pending_remote_device_list_updates_for_room_txn,
+ )
|