summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/account_data.py127
1 files changed, 91 insertions, 36 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..2d6f02c14f 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     cast,
@@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         return self._account_data_id_gen.get_current_token()
 
     @cached()
-    async def get_account_data_for_user(
+    async def get_global_account_data_for_user(
         self, user_id: str
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Mapping[str, JsonDict]:
         """
-        Get all the client account_data for a user.
+        Get all the global client account_data for a user.
 
         If experimental MSC3391 support is enabled, any entries with an empty
         content body are excluded; as this means they have been deleted.
 
         Args:
             user_id: The user to get the account_data for.
+
         Returns:
-            A 2-tuple of a dict of global account_data and a dict mapping from
-            room_id string to per room account_data dicts.
+            The global account_data.
         """
 
-        def get_account_data_for_user_txn(
+        def get_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+        ) -> Dict[str, JsonDict]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn.execute(sql, (user_id,))
             rows = self.db_pool.cursor_to_dict(txn)
 
-            global_account_data = {
+            return {
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
+        return await self.db_pool.runInteraction(
+            "get_global_account_data_for_user", get_global_account_data_for_user
+        )
+
+    @cached()
+    async def get_room_account_data_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
+        """
+        Get all of the per-room client account_data for a user.
+
+        If experimental MSC3391 support is enabled, any entries with an empty
+        content body are excluded; as this means they have been deleted.
+
+        Args:
+            user_id: The user to get the account_data for.
+
+        Returns:
+            A dict mapping from room_id string to per-room account_data dicts.
+        """
+
+        def get_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
 
-            return global_account_data, by_room
+            return by_room
 
         return await self.db_pool.runInteraction(
-            "get_account_data_for_user", get_account_data_for_user_txn
+            "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
         )
 
     @cached(num_args=2, max_entries=5000, tree=True)
@@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    async def get_updated_account_data_for_user(
+    async def get_updated_global_account_data_for_user(
         self, user_id: str, stream_id: int
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-        """Get all the client account_data for a that's changed for a user
+    ) -> Dict[str, JsonDict]:
+        """Get all the global account_data that's changed for a user.
 
         Args:
             user_id: The user to get the account_data for.
             stream_id: The point in the stream since which to get updates
+
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A dict of global account_data.
         """
 
-        def get_updated_account_data_for_user_txn(
+        def get_updated_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-            sql = (
-                "SELECT account_data_type, content FROM account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
-
+        ) -> Dict[str, JsonDict]:
+            sql = """
+                SELECT account_data_type, content FROM account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
-            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+            return {row[0]: db_to_json(row[1]) for row in txn}
 
-            sql = (
-                "SELECT room_id, account_data_type, content FROM room_account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
+        changed = self._account_data_stream_cache.has_entity_changed(
+            user_id, int(stream_id)
+        )
+        if not changed:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "get_updated_global_account_data_for_user",
+            get_updated_global_account_data_for_user,
+        )
+
+    async def get_updated_room_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Get all the room account_data that's changed for a user.
 
+        Args:
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
+
+        Returns:
+            A dict mapping from room_id string to per room account_data dicts.
+        """
+
+        def get_updated_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
+            sql = """
+                SELECT room_id, account_data_type, content FROM room_account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
 
-            return global_account_data, account_data_by_room
+            return account_data_by_room
 
         changed = self._account_data_stream_cache.has_entity_changed(
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return {}
 
         return await self.db_pool.runInteraction(
-            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+            "get_updated_room_account_data_for_user",
+            get_updated_room_account_data_for_user_txn,
         )
 
     @cached(max_entries=5000, iterable=True)
@@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                     self.get_global_account_data_by_type_for_user.invalidate(
                         (row.user_id, row.data_type)
                     )
-                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_global_account_data_for_user.invalidate((row.user_id,))
+                self.get_room_account_data_for_user.invalidate((row.user_id,))
                 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                 self.get_account_data_for_room_and_type.invalidate(
                     (row.user_id, row.room_id, row.data_type)
@@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), content
@@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 return None
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), {}
@@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
                 (user_id, account_data_type)
             )
@@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 return None
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.prefill(
                 (user_id, account_data_type), {}
             )
@@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room_and_type, (user_id,)
         )
         self._invalidate_cache_and_stream(
-            txn, self.get_account_data_for_user, (user_id,)
+            txn, self.get_global_account_data_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_room_account_data_for_user, (user_id,)
         )
         self._invalidate_cache_and_stream(
             txn, self.get_global_account_data_by_type_for_user, (user_id,)