summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <eric.eastwood@beta.gouv.fr>2024-07-24 20:56:50 -0500
committerEric Eastwood <eric.eastwood@beta.gouv.fr>2024-07-24 20:56:50 -0500
commit00404b3ab43582c2c589f3f0c5aa46377796b78f (patch)
treed8eff13f943dde8bb92f6c07a78ea8b478bd3aca
parentSliding Sync: Add Account Data extension (MSC3959) (#17477) (diff)
downloadsynapse-00404b3ab43582c2c589f3f0c5aa46377796b78f.tar.xz
Better standardize `find_relevant_room_ids_for_extension(...)`
-rw-r--r--synapse/handlers/sliding_sync.py193
-rw-r--r--synapse/types/handlers/__init__.py18
-rw-r--r--synapse/types/rest/client/__init__.py18
-rw-r--r--tests/rest/client/test_sync.py100
4 files changed, 213 insertions, 116 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 3231574402..103ebfbbdf 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -630,7 +630,8 @@ class SlidingSyncHandler:
 
         extensions = await self.get_extensions_response(
             sync_config=sync_config,
-            lists=lists,
+            actual_lists=lists,
+            actual_room_ids=rooms.keys(),
             from_token=from_token,
             to_token=to_token,
         )
@@ -1800,7 +1801,8 @@ class SlidingSyncHandler:
     async def get_extensions_response(
         self,
         sync_config: SlidingSyncConfig,
-        lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        actual_room_ids: Set[str],
         to_token: StreamToken,
         from_token: Optional[SlidingSyncStreamToken],
     ) -> SlidingSyncResult.Extensions:
@@ -1808,7 +1810,9 @@ class SlidingSyncHandler:
 
         Args:
             sync_config: Sync configuration
-            lists: Sliding window API. A map of list key to list results.
+            actual_lists: Sliding window API. A map of list key to list results in the
+                Sliding Sync response.
+            actual_room_ids: The actual room IDs in the the Sliding Sync response.
             to_token: The point in the stream to sync up to.
             from_token: The point in the stream to sync from.
         """
@@ -1837,18 +1841,100 @@ class SlidingSyncHandler:
         if sync_config.extensions.account_data is not None:
             account_data_response = await self.get_account_data_extension_response(
                 sync_config=sync_config,
-                lists=lists,
+                actual_lists=actual_lists,
+                actual_room_ids=actual_room_ids,
                 account_data_request=sync_config.extensions.account_data,
                 to_token=to_token,
                 from_token=from_token,
             )
 
+        receipts_response = None
+        if sync_config.extensions.receipts is not None:
+            receipts_response = await self.get_receipts_extension_response(
+                sync_config=sync_config,
+                actual_lists=actual_lists,
+                actual_room_ids=actual_room_ids,
+                receipts_request=sync_config.extensions.receipts,
+                to_token=to_token,
+                from_token=from_token,
+            )
+
         return SlidingSyncResult.Extensions(
             to_device=to_device_response,
             e2ee=e2ee_response,
             account_data=account_data_response,
+            receipts=receipts_response,
         )
 
+    def find_relevant_room_ids_for_extension(
+        self,
+        requested_lists: Optional[List[str]],
+        requested_rooms: Optional[List[str]],
+        actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        actual_room_ids: Set[str],
+    ) -> Set[str]:
+        """
+        Handle the reserved `lists`/`rooms` keys for extensions.
+
+        {"lists": []}                    // Do not process any lists.
+        {"lists": ["rooms", "dms"]}      // Process only a subset of lists.
+        {"lists": ["*"]}                 // Process all lists defined in the Sliding Window API. (This is the default.)
+
+        {"rooms": []}                    // Do not process any specific rooms.
+        {"rooms": ["!a:b", "!c:d"]}      // Process only a subset of room subscriptions.
+        {"rooms": ["*"]}                 // Process all room subscriptions defined in the Room Subscription API. (This is the default.)
+
+        Args:
+            requested_lists: The `lists` from the extension request.
+            requested_rooms: The `rooms` from the extension request.
+            actual_lists: The actual lists from the Sliding Sync response.
+            actual_room_subscriptions: The actual room subscriptions from the Sliding Sync request.
+        """
+
+        # We only want to include account data for rooms that are already in the sliding
+        # sync response AND that were requested in the account data request.
+        relevant_room_ids: Set[str] = set()
+
+        # See what rooms from the room subscriptions we should get account data for
+        if requested_rooms is not None:
+            for room_id in requested_rooms:
+                # A wildcard means we process all rooms from the room subscriptions
+                if room_id == "*":
+                    relevant_room_ids.update(actual_room_ids)
+                    break
+
+                if room_id in actual_room_ids:
+                    relevant_room_ids.add(room_id)
+
+        # See what rooms from the sliding window lists we should get account data for
+        if requested_lists is not None:
+            for list_key in requested_lists:
+                # Just some typing because we share the variable name in multiple places
+                actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
+
+                # A wildcard means we process rooms from all lists
+                if list_key == "*":
+                    for actual_list in actual_lists.values():
+                        # We only expect a single SYNC operation for any list
+                        assert len(actual_list.ops) == 1
+                        sync_op = actual_list.ops[0]
+                        assert sync_op.op == OperationType.SYNC
+
+                        relevant_room_ids.update(sync_op.room_ids)
+
+                    break
+
+                actual_list = actual_lists.get(list_key)
+                if actual_list is not None:
+                    # We only expect a single SYNC operation for any list
+                    assert len(actual_list.ops) == 1
+                    sync_op = actual_list.ops[0]
+                    assert sync_op.op == OperationType.SYNC
+
+                    relevant_room_ids.update(sync_op.room_ids)
+
+        return relevant_room_ids
+
     async def get_to_device_extension_response(
         self,
         sync_config: SlidingSyncConfig,
@@ -1976,7 +2062,8 @@ class SlidingSyncHandler:
     async def get_account_data_extension_response(
         self,
         sync_config: SlidingSyncConfig,
-        lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        actual_room_ids: Set[str],
         account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
         to_token: StreamToken,
         from_token: Optional[SlidingSyncStreamToken],
@@ -1985,7 +2072,9 @@ class SlidingSyncHandler:
 
         Args:
             sync_config: Sync configuration
-            lists: Sliding window API. A map of list key to list results.
+            actual_lists: Sliding window API. A map of list key to list results in the
+                Sliding Sync response.
+            actual_room_ids: The actual room IDs in the the Sliding Sync response.
             account_data_request: The account_data extension from the request
             to_token: The point in the stream to sync up to.
             from_token: The point in the stream to sync from.
@@ -2022,55 +2111,14 @@ class SlidingSyncHandler:
                 await self.push_rules_handler.push_rules_for_user(sync_config.user)
             )
 
-        # We only want to include account data for rooms that are already in the sliding
-        # sync response AND that were requested in the account data request.
-        relevant_room_ids: Set[str] = set()
-
-        # See what rooms from the room subscriptions we should get account data for
-        if (
-            account_data_request.rooms is not None
-            and sync_config.room_subscriptions is not None
-        ):
-            actual_room_ids = sync_config.room_subscriptions.keys()
-
-            for room_id in account_data_request.rooms:
-                # A wildcard means we process all rooms from the room subscriptions
-                if room_id == "*":
-                    relevant_room_ids.update(sync_config.room_subscriptions.keys())
-                    break
-
-                if room_id in actual_room_ids:
-                    relevant_room_ids.add(room_id)
-
-        # See what rooms from the sliding window lists we should get account data for
-        if account_data_request.lists is not None:
-            for list_key in account_data_request.lists:
-                # Just some typing because we share the variable name in multiple places
-                actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None
-
-                # A wildcard means we process rooms from all lists
-                if list_key == "*":
-                    for actual_list in lists.values():
-                        # We only expect a single SYNC operation for any list
-                        assert len(actual_list.ops) == 1
-                        sync_op = actual_list.ops[0]
-                        assert sync_op.op == OperationType.SYNC
-
-                        relevant_room_ids.update(sync_op.room_ids)
-
-                    break
-
-                actual_list = lists.get(list_key)
-                if actual_list is not None:
-                    # We only expect a single SYNC operation for any list
-                    assert len(actual_list.ops) == 1
-                    sync_op = actual_list.ops[0]
-                    assert sync_op.op == OperationType.SYNC
-
-                    relevant_room_ids.update(sync_op.room_ids)
-
         # Fetch room account data
         account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
+        relevant_room_ids = self.find_relevant_room_ids_for_extension(
+            requested_lists=account_data_request.lists,
+            requested_rooms=account_data_request.rooms,
+            actual_lists=actual_lists,
+            actual_room_ids=actual_room_ids,
+        )
         if len(relevant_room_ids) > 0:
             if from_token is not None:
                 account_data_by_room_map = (
@@ -2094,3 +2142,42 @@ class SlidingSyncHandler:
             global_account_data_map=global_account_data_map,
             account_data_by_room_map=account_data_by_room_map,
         )
+
+    async def get_receipts_extension_response(
+        self,
+        sync_config: SlidingSyncConfig,
+        actual_lists: Dict[str, SlidingSyncResult.SlidingWindowList],
+        actual_room_ids: Set[str],
+        receipts_request: SlidingSyncConfig.Extensions.ReceiptsExtension,
+        to_token: StreamToken,
+        from_token: Optional[SlidingSyncStreamToken],
+    ) -> Optional[SlidingSyncResult.Extensions.ReceiptsExtension]:
+        """Handle Receipts extension (MSC3960)
+
+        Args:
+            sync_config: Sync configuration
+            actual_lists: Sliding window API. A map of list key to list results in the
+                Sliding Sync response.
+            actual_room_ids: The actual room IDs in the the Sliding Sync response.
+            account_data_request: The account_data extension from the request
+            to_token: The point in the stream to sync up to.
+            from_token: The point in the stream to sync from.
+        """
+        user_id = sync_config.user.to_string()
+
+        # Skip if the extension is not enabled
+        if not receipts_request.enabled:
+            return None
+
+        # receipt_source = self.event_sources.sources.receipt
+        # receipts, receipt_key = await receipt_source.get_new_events(
+        #     user=sync_config.user,
+        #     from_key=receipt_key,
+        #     limit=sync_config.filter_collection.ephemeral_limit(),
+        #     room_ids=room_ids,
+        #     is_guest=sync_config.is_guest,
+        # )
+
+        return SlidingSyncResult.Extensions.ReceiptsExtension(
+            room_id_to_receipt_map=TODO,
+        )
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 479222a18d..654c50ef61 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -349,12 +349,28 @@ class SlidingSyncResult:
                     self.global_account_data_map or self.account_data_by_room_map
                 )
 
+        @attr.s(slots=True, frozen=True, auto_attribs=True)
+        class ReceiptsExtension:
+            """The Receipts extension (MSC3960)
+
+            Attributes:
+                room_id_to_receipt_map: Mapping from room_id to `m.receipt` event (type, content)
+            """
+
+            room_id_to_receipt_map: Mapping[str, Mapping[str, JsonMapping]]
+
+            def __bool__(self) -> bool:
+                return bool(self.room_id_to_receipt_map)
+
         to_device: Optional[ToDeviceExtension] = None
         e2ee: Optional[E2eeExtension] = None
         account_data: Optional[AccountDataExtension] = None
+        receipts: Optional[ReceiptsExtension] = None
 
         def __bool__(self) -> bool:
-            return bool(self.to_device or self.e2ee or self.account_data)
+            return bool(
+                self.to_device or self.e2ee or self.account_data or self.receipts
+            )
 
     next_pos: SlidingSyncStreamToken
     lists: Dict[str, SlidingWindowList]
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 34e07ddac5..ae8b7d8144 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -339,9 +339,27 @@ class SlidingSyncBody(RequestBodyModel):
             # Process all room subscriptions defined in the Room Subscription API. (This is the default.)
             rooms: Optional[List[StrictStr]] = ["*"]
 
+        class ReceiptsExtension(RequestBodyModel):
+            """The Receipts extension (MSC3960)
+
+            Attributes:
+                enabled
+                lists: List of list keys (from the Sliding Window API) to apply this
+                    extension to.
+                rooms: List of room IDs (from the Room Subscription API) to apply this
+                    extension to.
+            """
+
+            enabled: Optional[StrictBool] = False
+            # Process all lists defined in the Sliding Window API. (This is the default.)
+            lists: Optional[List[StrictStr]] = ["*"]
+            # Process all room subscriptions defined in the Room Subscription API. (This is the default.)
+            rooms: Optional[List[StrictStr]] = ["*"]
+
         to_device: Optional[ToDeviceExtension] = None
         e2ee: Optional[E2eeExtension] = None
         account_data: Optional[AccountDataExtension] = None
+        receipts: Optional[ReceiptsExtension] = None
 
     # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
     if TYPE_CHECKING:
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 135b677bad..276588ad2f 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -5941,8 +5941,7 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
             room_id5: "room5",
         }
 
-        # Mix lists and rooms
-        sync_body = {
+        main_sync_body = {
             "lists": {
                 # We expect this list range to include room5 and room4
                 "foo-list": {
@@ -5963,6 +5962,11 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
                     "timeline_limit": 0,
                 }
             },
+        }
+
+        # Mix lists and rooms
+        sync_body = {
+            **main_sync_body,
             "extensions": {
                 "account_data": {
                     "enabled": True,
@@ -5991,26 +5995,7 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
 
         # Try wildcards (this is the default)
         sync_body = {
-            "lists": {
-                # We expect this list range to include room5 and room4
-                "foo-list": {
-                    "ranges": [[0, 1]],
-                    "required_state": [],
-                    "timeline_limit": 0,
-                },
-                # We expect this list range to include room5, room4, room3
-                "bar-list": {
-                    "ranges": [[0, 2]],
-                    "required_state": [],
-                    "timeline_limit": 0,
-                },
-            },
-            "room_subscriptions": {
-                room_id1: {
-                    "required_state": [],
-                    "timeline_limit": 0,
-                }
-            },
+            **main_sync_body,
             "extensions": {
                 "account_data": {
                     "enabled": True,
@@ -6039,26 +6024,7 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
 
         # Empty list will return nothing
         sync_body = {
-            "lists": {
-                # We expect this list range to include room5 and room4
-                "foo-list": {
-                    "ranges": [[0, 1]],
-                    "required_state": [],
-                    "timeline_limit": 0,
-                },
-                # We expect this list range to include room5, room4, room3
-                "bar-list": {
-                    "ranges": [[0, 2]],
-                    "required_state": [],
-                    "timeline_limit": 0,
-                },
-            },
-            "room_subscriptions": {
-                room_id1: {
-                    "required_state": [],
-                    "timeline_limit": 0,
-                }
-            },
+            **main_sync_body,
             "extensions": {
                 "account_data": {
                     "enabled": True,
@@ -6087,26 +6053,7 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
 
         # Try wildcard and none
         sync_body = {
-            "lists": {
-                # We expect this list range to include room5 and room4
-                "foo-list": {
-                    "ranges": [[0, 1]],
-                    "required_state": [],
-                    "timeline_limit": 0,
-                },
-                # We expect this list range to include room5, room4, room3
-                "bar-list": {
-                    "ranges": [[0, 2]],
-                    "required_state": [],
-                    "timeline_limit": 0,
-                },
-            },
-            "room_subscriptions": {
-                room_id1: {
-                    "required_state": [],
-                    "timeline_limit": 0,
-                }
-            },
+            **main_sync_body,
             "extensions": {
                 "account_data": {
                     "enabled": True,
@@ -6133,6 +6080,35 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
             exact=True,
         )
 
+        # Try requesting a room that is only in a list
+        sync_body = {
+            **main_sync_body,
+            "extensions": {
+                "account_data": {
+                    "enabled": True,
+                    "lists": [],
+                    "rooms": [room_id5],
+                }
+            },
+        }
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+        # room1: ❌ Not requested
+        # room2: ❌ Not requested
+        # room3: ❌ Not requested
+        # room4: ❌ Not requested
+        # room5: ✅ Requested via `rooms` and is in a list
+        self.assertIncludes(
+            {
+                room_id_to_human_name_map[room_id]
+                for room_id in response_body["extensions"]["account_data"]
+                .get("rooms")
+                .keys()
+            },
+            {"room5"},
+            exact=True,
+        )
+
     def test_wait_for_new_data(self) -> None:
         """
         Test to make sure that the Sliding Sync request waits for new data to arrive.