diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..2d6f02c14f 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
cast,
@@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
@cached()
- async def get_account_data_for_user(
+ async def get_global_account_data_for_user(
self, user_id: str
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Mapping[str, JsonDict]:
"""
- Get all the client account_data for a user.
+ Get all the global client account_data for a user.
If experimental MSC3391 support is enabled, any entries with an empty
content body are excluded; as this means they have been deleted.
Args:
user_id: The user to get the account_data for.
+
Returns:
- A 2-tuple of a dict of global account_data and a dict mapping from
- room_id string to per room account_data dicts.
+ The global account_data.
"""
- def get_account_data_for_user_txn(
+ def get_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Dict[str, JsonDict]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
- global_account_data = {
+ return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
+ return await self.db_pool.runInteraction(
+ "get_global_account_data_for_user", get_global_account_data_for_user
+ )
+
+ @cached()
+ async def get_room_account_data_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
+ """
+ Get all of the per-room client account_data for a user.
+
+ If experimental MSC3391 support is enabled, any entries with an empty
+ content body are excluded; as this means they have been deleted.
+
+ Args:
+ user_id: The user to get the account_data for.
+
+ Returns:
+ A dict mapping from room_id string to per-room account_data dicts.
+ """
+
+ def get_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_data[row["account_data_type"]] = db_to_json(row["content"])
- return global_account_data, by_room
+ return by_room
return await self.db_pool.runInteraction(
- "get_account_data_for_user", get_account_data_for_user_txn
+ "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
)
@cached(num_args=2, max_entries=5000, tree=True)
@@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- async def get_updated_account_data_for_user(
+ async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- """Get all the client account_data for a that's changed for a user
+ ) -> Dict[str, JsonDict]:
+ """Get all the global account_data that's changed for a user.
Args:
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
+
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A dict of global account_data.
"""
- def get_updated_account_data_for_user_txn(
+ def get_updated_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- sql = (
- "SELECT account_data_type, content FROM account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
-
+ ) -> Dict[str, JsonDict]:
+ sql = """
+ SELECT account_data_type, content FROM account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+ return {row[0]: db_to_json(row[1]) for row in txn}
- sql = (
- "SELECT room_id, account_data_type, content FROM room_account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return {}
+
+ return await self.db_pool.runInteraction(
+ "get_updated_global_account_data_for_user",
+ get_updated_global_account_data_for_user,
+ )
+
+ async def get_updated_room_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Get all the room account_data that's changed for a user.
+ Args:
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
+
+ Returns:
+ A dict mapping from room_id string to per room account_data dicts.
+ """
+
+ def get_updated_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ sql = """
+ SELECT room_id, account_data_type, content FROM room_account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
- return global_account_data, account_data_by_room
+ return account_data_by_room
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
- return {}, {}
+ return {}
return await self.db_pool.runInteraction(
- "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+ "get_updated_room_account_data_for_user",
+ get_updated_room_account_data_for_user_txn,
)
@cached(max_entries=5000, iterable=True)
@@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self.get_global_account_data_by_type_for_user.invalidate(
(row.user_id, row.data_type)
)
- self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_global_account_data_for_user.invalidate((row.user_id,))
+ self.get_room_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
@@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), content
@@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), {}
@@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(user_id, account_data_type)
)
@@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.prefill(
(user_id, account_data_type), {}
)
@@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room_and_type, (user_id,)
)
self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_user, (user_id,)
+ txn, self.get_global_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_room_account_data_for_user, (user_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_global_account_data_by_type_for_user, (user_id,)
|