summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/devices.py64
-rw-r--r--synapse/storage/databases/main/events.py6
-rw-r--r--synapse/storage/databases/main/roommember.py46
4 files changed, 83 insertions, 35 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 4dccbb732a..0843f10340 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -83,6 +83,7 @@ logger = logging.getLogger(__name__)
 
 class DataStore(
     EventsBackgroundUpdatesStore,
+    DeviceStore,
     RoomMemberStore,
     RoomStore,
     RoomBatchStore,
@@ -114,7 +115,6 @@ class DataStore(
     StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,
-    DeviceStore,
     DeviceInboxStore,
     UserDirectoryStore,
     UserErasureStore,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 5d700ca6c3..1151fb0cc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
@@ -70,7 +71,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
-class DeviceWorkerStore(EndToEndKeyWorkerStore):
+class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -985,24 +986,59 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="mark_remote_user_device_cache_as_valid",
         )
 
+    async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        await self.db_pool.runInteraction(
+            "_handle_potentially_left_users",
+            self.handle_potentially_left_users_txn,
+            user_ids,
+        )
+
+    def handle_potentially_left_users_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Set[str],
+    ) -> None:
+        """Given a set of remote users check if the server still shares a room with
+        them. If not then mark those users' device cache as stale.
+        """
+
+        if not user_ids:
+            return
+
+        joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
+        left_users = user_ids - joined_users
+
+        for user_id in left_users:
+            self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
+
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(
-            txn: LoggingTransaction,
-        ) -> None:
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="device_lists_remote_extremeties",
-                keyvalues={"user_id": user_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
-            )
-
         await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
-            _mark_remote_user_device_list_as_unsubscribed_txn,
+            self.mark_remote_user_device_list_as_unsubscribed_txn,
+            user_id,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+    ) -> None:
+        self.db_pool.simple_delete_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
         )
 
     async def get_dehydrated_device(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2e156a4a11..b59eb7478b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1202,6 +1202,12 @@ class PersistEventsStore:
                 txn, room_id, members_changed
             )
 
+            # Check if any of the remote membership changes requires us to
+            # unsubscribe from their device lists.
+            self.store.handle_potentially_left_users_txn(
+                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+            )
+
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a8d224602a..8ada3cdac3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -662,31 +662,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(
-            txn: LoggingTransaction,
-        ) -> Set[str]:
-            sql = """
-                SELECT state_key FROM current_state_events
-                WHERE
-                    type = 'm.room.member'
-                    AND membership = 'join'
-                    AND %s
-                GROUP BY state_key
-            """
-
-            clause, args = make_in_list_sql_clause(
-                self.database_engine, "state_key", user_ids
-            )
+        return await self.db_pool.runInteraction(
+            "get_users_server_still_shares_room_with",
+            self.get_users_server_still_shares_room_with_txn,
+            user_ids,
+        )
 
-            txn.execute(sql % (clause,), args)
+    def get_users_server_still_shares_room_with_txn(
+        self,
+        txn: LoggingTransaction,
+        user_ids: Collection[str],
+    ) -> Set[str]:
+        if not user_ids:
+            return set()
 
-            return {row[0] for row in txn}
+        sql = """
+            SELECT state_key FROM current_state_events
+            WHERE
+                type = 'm.room.member'
+                AND membership = 'join'
+                AND %s
+            GROUP BY state_key
+        """
 
-        return await self.db_pool.runInteraction(
-            "get_users_server_still_shares_room_with",
-            _get_users_server_still_shares_room_with_txn,
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "state_key", user_ids
         )
 
+        txn.execute(sql % (clause,), args)
+
+        return {row[0] for row in txn}
+
     @cancellable
     async def get_rooms_for_user(
         self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None