summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/filtering.py30
-rw-r--r--synapse/handlers/account_data.py10
-rw-r--r--synapse/handlers/initial_sync.py5
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/sync.py88
-rw-r--r--synapse/rest/admin/users.py3
-rw-r--r--synapse/storage/databases/main/account_data.py127
7 files changed, 175 insertions, 90 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 83c42fc25a..b9f432cc23 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -219,9 +219,13 @@ class FilterCollection:
         self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
         self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
         self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
-        self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
+        self._room_account_data_filter = Filter(
+            hs, room_filter_json.get("account_data", {})
+        )
         self._presence_filter = Filter(hs, filter_json.get("presence", {}))
-        self._account_data = Filter(hs, filter_json.get("account_data", {}))
+        self._global_account_data_filter = Filter(
+            hs, filter_json.get("account_data", {})
+        )
 
         self.include_leave = filter_json.get("room", {}).get("include_leave", False)
         self.event_fields = filter_json.get("event_fields", [])
@@ -256,8 +260,10 @@ class FilterCollection:
     ) -> List[UserPresenceState]:
         return await self._presence_filter.filter(presence_states)
 
-    async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
-        return await self._account_data.filter(events)
+    async def filter_global_account_data(
+        self, events: Iterable[JsonDict]
+    ) -> List[JsonDict]:
+        return await self._global_account_data_filter.filter(events)
 
     async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
         return await self._room_state_filter.filter(
@@ -279,7 +285,7 @@ class FilterCollection:
     async def filter_room_account_data(
         self, events: Iterable[JsonDict]
     ) -> List[JsonDict]:
-        return await self._room_account_data.filter(
+        return await self._room_account_data_filter.filter(
             await self._room_filter.filter(events)
         )
 
@@ -292,6 +298,13 @@ class FilterCollection:
             or self._presence_filter.filters_all_senders()
         )
 
+    def blocks_all_global_account_data(self) -> bool:
+        """True if all global acount data will be filtered out."""
+        return (
+            self._global_account_data_filter.filters_all_types()
+            or self._global_account_data_filter.filters_all_senders()
+        )
+
     def blocks_all_room_ephemeral(self) -> bool:
         return (
             self._room_ephemeral_filter.filters_all_types()
@@ -299,6 +312,13 @@ class FilterCollection:
             or self._room_ephemeral_filter.filters_all_rooms()
         )
 
+    def blocks_all_room_account_data(self) -> bool:
+        return (
+            self._room_account_data_filter.filters_all_types()
+            or self._room_account_data_filter.filters_all_senders()
+            or self._room_account_data_filter.filters_all_rooms()
+        )
+
     def blocks_all_room_timeline(self) -> bool:
         return (
             self._room_timeline_filter.filters_all_types()
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 67e789eef7..797de46dbc 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -343,10 +343,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
                 }
             )
 
-        (
-            account_data,
-            room_account_data,
-        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        account_data = await self.store.get_updated_global_account_data_for_user(
+            user_id, last_stream_id
+        )
+        room_account_data = await self.store.get_updated_room_account_data_for_user(
+            user_id, last_stream_id
+        )
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 191529bd8e..1a29abde98 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -154,9 +154,8 @@ class InitialSyncHandler:
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
-        account_data, account_data_by_room = await self.store.get_account_data_for_user(
-            user_id
-        )
+        account_data = await self.store.get_global_account_data_for_user(user_id)
+        account_data_by_room = await self.store.get_room_account_data_for_user(user_id)
 
         public_room_ids = await self.store.get_public_room_ids()
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d236cc09b5..6e7141d2ef 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -484,7 +484,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
-        user_account_data, _ = await self.store.get_account_data_for_user(user_id)
+        user_account_data = await self.store.get_global_account_data_for_user(user_id)
 
         # Copy direct message state if applicable
         direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 202b35eee6..399685e5b7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1444,9 +1444,9 @@ class SyncHandler:
 
         logger.debug("Fetching account data")
 
-        account_data_by_room = await self._generate_sync_entry_for_account_data(
-            sync_result_builder
-        )
+        # Global account data is included if it is not filtered out.
+        if not sync_config.filter_collection.blocks_all_global_account_data():
+            await self._generate_sync_entry_for_account_data(sync_result_builder)
 
         # Presence data is included if the server has it enabled and not filtered out.
         include_presence_data = bool(
@@ -1472,9 +1472,7 @@ class SyncHandler:
             (
                 newly_joined_rooms,
                 newly_left_rooms,
-            ) = await self._generate_sync_entry_for_rooms(
-                sync_result_builder, account_data_by_room
-            )
+            ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
 
             # Work out which users have joined or left rooms we're in. We use this
             # to build the presence and device_list parts of the sync response in
@@ -1717,35 +1715,29 @@ class SyncHandler:
 
     async def _generate_sync_entry_for_account_data(
         self, sync_result_builder: "SyncResultBuilder"
-    ) -> Dict[str, Dict[str, JsonDict]]:
-        """Generates the account data portion of the sync response.
+    ) -> None:
+        """Generates the global account data portion of the sync response.
 
         Account data (called "Client Config" in the spec) can be set either globally
         or for a specific room. Account data consists of a list of events which
         accumulate state, much like a room.
 
-        This function retrieves global and per-room account data. The former is written
-        to the given `sync_result_builder`. The latter is returned directly, to be
-        later written to the `sync_result_builder` on a room-by-room basis.
+        This function retrieves global account data and writes it to the given
+        `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling
+         of per-room account data.
 
         Args:
             sync_result_builder
-
-        Returns:
-            A dictionary whose keys (room ids) map to the per room account data for that
-            room.
         """
         sync_config = sync_result_builder.sync_config
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
 
         if since_token and not sync_result_builder.full_state:
-            # TODO Do not fetch room account data if it will be unused.
-            (
-                global_account_data,
-                account_data_by_room,
-            ) = await self.store.get_updated_account_data_for_user(
-                user_id, since_token.account_data_key
+            global_account_data = (
+                await self.store.get_updated_global_account_data_for_user(
+                    user_id, since_token.account_data_key
+                )
             )
 
             push_rules_changed = await self.store.have_push_rules_changed_for_user(
@@ -1758,28 +1750,26 @@ class SyncHandler:
                     sync_config.user
                 )
         else:
-            # TODO Do not fetch room account data if it will be unused.
-            (
-                global_account_data,
-                account_data_by_room,
-            ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
+            all_global_account_data = await self.store.get_global_account_data_for_user(
+                user_id
+            )
 
-            global_account_data = dict(global_account_data)
+            global_account_data = dict(all_global_account_data)
             global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
-        account_data_for_user = await sync_config.filter_collection.filter_account_data(
-            [
-                {"type": account_data_type, "content": content}
-                for account_data_type, content in global_account_data.items()
-            ]
+        account_data_for_user = (
+            await sync_config.filter_collection.filter_global_account_data(
+                [
+                    {"type": account_data_type, "content": content}
+                    for account_data_type, content in global_account_data.items()
+                ]
+            )
         )
 
         sync_result_builder.account_data = account_data_for_user
 
-        return account_data_by_room
-
     async def _generate_sync_entry_for_presence(
         self,
         sync_result_builder: "SyncResultBuilder",
@@ -1839,9 +1829,7 @@ class SyncHandler:
         sync_result_builder.presence = presence
 
     async def _generate_sync_entry_for_rooms(
-        self,
-        sync_result_builder: "SyncResultBuilder",
-        account_data_by_room: Dict[str, Dict[str, JsonDict]],
+        self, sync_result_builder: "SyncResultBuilder"
     ) -> Tuple[AbstractSet[str], AbstractSet[str]]:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
@@ -1852,7 +1840,6 @@ class SyncHandler:
 
         Args:
             sync_result_builder
-            account_data_by_room: Dictionary of per room account data
 
         Returns:
             Returns a 2-tuple describing rooms the user has joined or left.
@@ -1865,9 +1852,30 @@ class SyncHandler:
         since_token = sync_result_builder.since_token
         user_id = sync_result_builder.sync_config.user.to_string()
 
+        blocks_all_rooms = (
+            sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+        )
+
+        # 0. Start by fetching room account data (if required).
+        if (
+            blocks_all_rooms
+            or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
+        ):
+            account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+        elif since_token and not sync_result_builder.full_state:
+            account_data_by_room = (
+                await self.store.get_updated_room_account_data_for_user(
+                    user_id, since_token.account_data_key
+                )
+            )
+        else:
+            account_data_by_room = await self.store.get_room_account_data_for_user(
+                user_id
+            )
+
         # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
         block_all_room_ephemeral = (
-            sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+            blocks_all_rooms
             or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
         )
         if block_all_room_ephemeral:
@@ -2294,7 +2302,7 @@ class SyncHandler:
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
         tags: Optional[Dict[str, Dict[str, Any]]],
-        account_data: Dict[str, JsonDict],
+        account_data: Mapping[str, JsonDict],
         always_include: bool = False,
     ) -> None:
         """Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b9dca8ef3a..0c0bf540b9 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -1192,7 +1192,8 @@ class AccountDataRestServlet(RestServlet):
         if not await self._store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
 
-        global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+        global_data = await self._store.get_global_account_data_for_user(user_id)
+        by_room_data = await self._store.get_room_account_data_for_user(user_id)
         return HTTPStatus.OK, {
             "account_data": {
                 "global": global_data,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..2d6f02c14f 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     cast,
@@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         return self._account_data_id_gen.get_current_token()
 
     @cached()
-    async def get_account_data_for_user(
+    async def get_global_account_data_for_user(
         self, user_id: str
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Mapping[str, JsonDict]:
         """
-        Get all the client account_data for a user.
+        Get all the global client account_data for a user.
 
         If experimental MSC3391 support is enabled, any entries with an empty
         content body are excluded; as this means they have been deleted.
 
         Args:
             user_id: The user to get the account_data for.
+
         Returns:
-            A 2-tuple of a dict of global account_data and a dict mapping from
-            room_id string to per room account_data dicts.
+            The global account_data.
         """
 
-        def get_account_data_for_user_txn(
+        def get_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+        ) -> Dict[str, JsonDict]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn.execute(sql, (user_id,))
             rows = self.db_pool.cursor_to_dict(txn)
 
-            global_account_data = {
+            return {
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
+        return await self.db_pool.runInteraction(
+            "get_global_account_data_for_user", get_global_account_data_for_user
+        )
+
+    @cached()
+    async def get_room_account_data_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
+        """
+        Get all of the per-room client account_data for a user.
+
+        If experimental MSC3391 support is enabled, any entries with an empty
+        content body are excluded; as this means they have been deleted.
+
+        Args:
+            user_id: The user to get the account_data for.
+
+        Returns:
+            A dict mapping from room_id string to per-room account_data dicts.
+        """
+
+        def get_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
 
-            return global_account_data, by_room
+            return by_room
 
         return await self.db_pool.runInteraction(
-            "get_account_data_for_user", get_account_data_for_user_txn
+            "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
         )
 
     @cached(num_args=2, max_entries=5000, tree=True)
@@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    async def get_updated_account_data_for_user(
+    async def get_updated_global_account_data_for_user(
         self, user_id: str, stream_id: int
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-        """Get all the client account_data for a that's changed for a user
+    ) -> Dict[str, JsonDict]:
+        """Get all the global account_data that's changed for a user.
 
         Args:
             user_id: The user to get the account_data for.
             stream_id: The point in the stream since which to get updates
+
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A dict of global account_data.
         """
 
-        def get_updated_account_data_for_user_txn(
+        def get_updated_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-            sql = (
-                "SELECT account_data_type, content FROM account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
-
+        ) -> Dict[str, JsonDict]:
+            sql = """
+                SELECT account_data_type, content FROM account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
-            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+            return {row[0]: db_to_json(row[1]) for row in txn}
 
-            sql = (
-                "SELECT room_id, account_data_type, content FROM room_account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
+        changed = self._account_data_stream_cache.has_entity_changed(
+            user_id, int(stream_id)
+        )
+        if not changed:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "get_updated_global_account_data_for_user",
+            get_updated_global_account_data_for_user,
+        )
+
+    async def get_updated_room_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Get all the room account_data that's changed for a user.
 
+        Args:
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
+
+        Returns:
+            A dict mapping from room_id string to per room account_data dicts.
+        """
+
+        def get_updated_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
+            sql = """
+                SELECT room_id, account_data_type, content FROM room_account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
 
-            return global_account_data, account_data_by_room
+            return account_data_by_room
 
         changed = self._account_data_stream_cache.has_entity_changed(
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return {}
 
         return await self.db_pool.runInteraction(
-            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+            "get_updated_room_account_data_for_user",
+            get_updated_room_account_data_for_user_txn,
         )
 
     @cached(max_entries=5000, iterable=True)
@@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                     self.get_global_account_data_by_type_for_user.invalidate(
                         (row.user_id, row.data_type)
                     )
-                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_global_account_data_for_user.invalidate((row.user_id,))
+                self.get_room_account_data_for_user.invalidate((row.user_id,))
                 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                 self.get_account_data_for_room_and_type.invalidate(
                     (row.user_id, row.room_id, row.data_type)
@@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), content
@@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 return None
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), {}
@@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
                 (user_id, account_data_type)
             )
@@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 return None
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.prefill(
                 (user_id, account_data_type), {}
             )
@@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room_and_type, (user_id,)
         )
         self._invalidate_cache_and_stream(
-            txn, self.get_account_data_for_user, (user_id,)
+            txn, self.get_global_account_data_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_room_account_data_for_user, (user_id,)
         )
         self._invalidate_cache_and_stream(
             txn, self.get_global_account_data_by_type_for_user, (user_id,)