diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 4dccbb732a..0843f10340 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -83,6 +83,7 @@ logger = logging.getLogger(__name__)
class DataStore(
EventsBackgroundUpdatesStore,
+ DeviceStore,
RoomMemberStore,
RoomStore,
RoomBatchStore,
@@ -114,7 +115,6 @@ class DataStore(
StreamWorkerStore,
OpenIdStore,
ClientIpWorkerStore,
- DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
UserErasureStore,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 5d700ca6c3..1151fb0cc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
@@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
desc="mark_remote_user_device_cache_as_valid",
)
+ async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ await self.db_pool.runInteraction(
+ "_handle_potentially_left_users",
+ self.handle_potentially_left_users_txn,
+ user_ids,
+ )
+
+ def handle_potentially_left_users_txn(
+ self,
+ txn: LoggingTransaction,
+ user_ids: Set[str],
+ ) -> None:
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+ left_users = user_ids - joined_users
+
+ for user_id in left_users:
+ self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user."""
- def _mark_remote_user_device_list_as_unsubscribed_txn(
- txn: LoggingTransaction,
- ) -> None:
- self.db_pool.simple_delete_txn(
- txn,
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
- )
-
await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
- _mark_remote_user_device_list_as_unsubscribed_txn,
+ self.mark_remote_user_device_list_as_unsubscribed_txn,
+ user_id,
+ )
+
+ def mark_remote_user_device_list_as_unsubscribed_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ ) -> None:
+ self.db_pool.simple_delete_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
async def get_dehydrated_device(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2e156a4a11..b59eb7478b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1202,6 +1202,12 @@ class PersistEventsStore:
txn, room_id, members_changed
)
+ # Check if any of the remote membership changes requires us to
+ # unsubscribe from their device lists.
+ self.store.handle_potentially_left_users_txn(
+ txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+ )
+
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
"""Update the room version in the database based off current state
events.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a8d224602a..8ada3cdac3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -662,31 +662,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not user_ids:
return set()
- def _get_users_server_still_shares_room_with_txn(
- txn: LoggingTransaction,
- ) -> Set[str]:
- sql = """
- SELECT state_key FROM current_state_events
- WHERE
- type = 'm.room.member'
- AND membership = 'join'
- AND %s
- GROUP BY state_key
- """
-
- clause, args = make_in_list_sql_clause(
- self.database_engine, "state_key", user_ids
- )
+ return await self.db_pool.runInteraction(
+ "get_users_server_still_shares_room_with",
+ self.get_users_server_still_shares_room_with_txn,
+ user_ids,
+ )
- txn.execute(sql % (clause,), args)
+ def get_users_server_still_shares_room_with_txn(
+ self,
+ txn: LoggingTransaction,
+ user_ids: Collection[str],
+ ) -> Set[str]:
+ if not user_ids:
+ return set()
- return {row[0] for row in txn}
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND %s
+ GROUP BY state_key
+ """
- return await self.db_pool.runInteraction(
- "get_users_server_still_shares_room_with",
- _get_users_server_still_shares_room_with_txn,
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
)
+ txn.execute(sql % (clause,), args)
+
+ return {row[0] for row in txn}
+
@cancellable
async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
|