summary refs log tree commit diff
path: root/synapse/handlers/sliding_sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/sliding_sync.py')
-rw-r--r--synapse/handlers/sliding_sync.py138
1 files changed, 138 insertions, 0 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index f1f6f30b95..3231574402 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -46,6 +46,7 @@ from synapse.storage.roommember import MemberSummary
 from synapse.types import (
     DeviceListUpdates,
     JsonDict,
+    JsonMapping,
     PersistedEventPosition,
     Requester,
     RoomStreamToken,
@@ -357,6 +358,7 @@ class SlidingSyncHandler:
         self.event_sources = hs.get_event_sources()
         self.relations_handler = hs.get_relations_handler()
         self.device_handler = hs.get_device_handler()
+        self.push_rules_handler = hs.get_push_rules_handler()
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
 
     async def wait_for_sync_for_user(
@@ -628,6 +630,7 @@ class SlidingSyncHandler:
 
         extensions = await self.get_extensions_response(
             sync_config=sync_config,
+            lists=lists,
             from_token=from_token,
             to_token=to_token,
         )
@@ -1797,6 +1800,7 @@ class SlidingSyncHandler:
     async def get_extensions_response(
         self,
         sync_config: SlidingSyncConfig,
+        lists: Dict[str, SlidingSyncResult.SlidingWindowList],
         to_token: StreamToken,
         from_token: Optional[SlidingSyncStreamToken],
     ) -> SlidingSyncResult.Extensions:
@@ -1804,6 +1808,7 @@ class SlidingSyncHandler:
 
         Args:
             sync_config: Sync configuration
+            lists: Sliding window API. A map of list key to list results.
             to_token: The point in the stream to sync up to.
             from_token: The point in the stream to sync from.
         """
@@ -1828,9 +1833,20 @@ class SlidingSyncHandler:
                 from_token=from_token,
             )
 
+        account_data_response = None
+        if sync_config.extensions.account_data is not None:
+            account_data_response = await self.get_account_data_extension_response(
+                sync_config=sync_config,
+                lists=lists,
+                account_data_request=sync_config.extensions.account_data,
+                to_token=to_token,
+                from_token=from_token,
+            )
+
         return SlidingSyncResult.Extensions(
             to_device=to_device_response,
             e2ee=e2ee_response,
+            account_data=account_data_response,
         )
 
     async def get_to_device_extension_response(
@@ -1956,3 +1972,125 @@ class SlidingSyncHandler:
             device_one_time_keys_count=device_one_time_keys_count,
             device_unused_fallback_key_types=device_unused_fallback_key_types,
         )
+
+    async def get_account_data_extension_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
+        to_token: StreamToken,
+        from_token: Optional[SlidingSyncStreamToken],
+    ) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
+        """Handle Account Data extension (MSC3959)
+
+        Args:
+            sync_config: Sync configuration
+            lists: Sliding window API. A map of list key to list results.
+            account_data_request: The account_data extension from the request
+            to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
+        """
+        user_id = sync_config.user.to_string()
+
+        # Skip if the extension is not enabled
+        if not account_data_request.enabled:
+            return None
+
+        global_account_data_map: Mapping[str, JsonMapping] = {}
+        if from_token is not None:
+            global_account_data_map = (
+                await self.store.get_updated_global_account_data_for_user(
+                    user_id, from_token.stream_token.account_data_key
+                )
+            )
+
+            have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
+                user_id, from_token.stream_token.push_rules_key
+            )
+            if have_push_rules_changed:
+                global_account_data_map = dict(global_account_data_map)
+                global_account_data_map[AccountDataTypes.PUSH_RULES] = (
+                    await self.push_rules_handler.push_rules_for_user(sync_config.user)
+                )
+        else:
+            all_global_account_data = await self.store.get_global_account_data_for_user(
+                user_id
+            )
+
+            global_account_data_map = dict(all_global_account_data)
+            global_account_data_map[AccountDataTypes.PUSH_RULES] = (
+                await self.push_rules_handler.push_rules_for_user(sync_config.user)
+            )
+
+        # We only want to include account data for rooms that are already in the sliding
+        # sync response AND that were requested in the account data request.
+        relevant_room_ids: Set[str] = set()
+
+        # See what rooms from the room subscriptions we should get account data for
+        if (
+            account_data_request.rooms is not None
+            and sync_config.room_subscriptions is not None
+        ):
+            actual_room_ids = sync_config.room_subscriptions.keys()
+
+            for room_id in account_data_request.rooms:
+                # A wildcard means we process all rooms from the room subscriptions
+                if room_id == "*":
+                    relevant_room_ids.update(sync_config.room_subscriptions.keys())
+                    break
+
+                if room_id in actual_room_ids:
+                    relevant_room_ids.add(room_id)
+
+        # See what rooms from the sliding window lists we should get account data for
+        if account_data_request.lists is not None:
+            for list_key in account_data_request.lists:
+                # Just some typing because we share the variable name in multiple places
+                actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
+
+                # A wildcard means we process rooms from all lists
+                if list_key == "*":
+                    for actual_list in lists.values():
+                        # We only expect a single SYNC operation for any list
+                        assert len(actual_list.ops) == 1
+                        sync_op = actual_list.ops[0]
+                        assert sync_op.op == OperationType.SYNC
+
+                        relevant_room_ids.update(sync_op.room_ids)
+
+                    break
+
+                actual_list = lists.get(list_key)
+                if actual_list is not None:
+                    # We only expect a single SYNC operation for any list
+                    assert len(actual_list.ops) == 1
+                    sync_op = actual_list.ops[0]
+                    assert sync_op.op == OperationType.SYNC
+
+                    relevant_room_ids.update(sync_op.room_ids)
+
+        # Fetch room account data
+        account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
+        if len(relevant_room_ids) > 0:
+            if from_token is not None:
+                account_data_by_room_map = (
+                    await self.store.get_updated_room_account_data_for_user(
+                        user_id, from_token.stream_token.account_data_key
+                    )
+                )
+            else:
+                account_data_by_room_map = (
+                    await self.store.get_room_account_data_for_user(user_id)
+                )
+
+        # Filter down to the relevant rooms
+        account_data_by_room_map = {
+            room_id: account_data_map
+            for room_id, account_data_map in account_data_by_room_map.items()
+            if room_id in relevant_room_ids
+        }
+
+        return SlidingSyncResult.Extensions.AccountDataExtension(
+            global_account_data_map=global_account_data_map,
+            account_data_by_room_map=account_data_by_room_map,
+        )