summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <eric.eastwood@beta.gouv.fr>2024-07-25 00:44:18 -0500
committerEric Eastwood <eric.eastwood@beta.gouv.fr>2024-07-25 00:44:18 -0500
commita294b4196ac97662adb8ecd2021449ba653889d6 (patch)
treeb576477120851586e67e89670281de5f34a3e130
parentIterate more (diff)
downloadsynapse-a294b4196ac97662adb8ecd2021449ba653889d6.tar.xz
Generalize extension test
-rw-r--r--synapse/handlers/sliding_sync.py2
-rw-r--r--synapse/rest/client/sync.py6
-rw-r--r--tests/rest/client/test_sync.py82
3 files changed, 72 insertions, 18 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index eb47569402..0c7299137d 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -2205,7 +2205,7 @@ class SlidingSyncHandler:
                 room_id = receipt["room_id"]
                 type = receipt["type"]
                 content = receipt["content"]
-                room_id_to_receipt_map[room_id] = {type: type, content: content}
+                room_id_to_receipt_map[room_id] = {"type": type, "content": content}
 
         return SlidingSyncResult.Extensions.ReceiptsExtension(
             room_id_to_receipt_map=room_id_to_receipt_map,
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 7cf1f56435..268c6521e0 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -1134,6 +1134,12 @@ class SlidingSyncRestServlet(RestServlet):
                 },
             }
 
+        if extensions.receipts is not None:
+            serialized_extensions["receipts"] = {
+                # Same as the the top-level `account_data.events` field in Sync v2.
+                "rooms": extensions.receipts.room_id_to_receipt_map,
+            }
+
         return serialized_extensions
 
 
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 5047313941..2581b58b5c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -1280,6 +1280,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         room.register_servlets,
         sync.register_servlets,
         devices.register_servlets,
+        receipts.register_servlets,
     ]
 
     def default_config(self) -> JsonDict:
@@ -4625,10 +4626,12 @@ class SlidingSyncTestCase(SlidingSyncBase):
             channel.json_body["rooms"].get(room_id1), channel.json_body["rooms"]
         )
 
+    # Any extensions that use `lists`/`rooms` should be tested here
     @parameterized.expand([("account_data",), ("receipts",)])
     def test_extensions_lists_rooms_relevant_rooms(self, extension_name: str) -> None:
         """
-        Test out different variations of `lists`/`rooms` we are requesting extensions for.
+        With various extensions, test out requesting different variations of
+        `lists`/`rooms`.
         """
         user1_id = self.register_user("user1", "pass")
         user1_tok = self.login(user1_id, "pass")
@@ -4649,15 +4652,30 @@ class SlidingSyncTestCase(SlidingSyncBase):
         }
 
         for room_id in room_id_to_human_name_map.keys():
-            # Add some account data to each room
-            self.get_success(
-                self.account_data_handler.add_account_data_to_room(
-                    user_id=user1_id,
-                    room_id=room_id,
-                    account_data_type="org.matrix.roorarraz",
-                    content={"roo": "rar"},
+            if extension_name == "account_data":
+                # Add some account data to each room
+                self.get_success(
+                    self.account_data_handler.add_account_data_to_room(
+                        user_id=user1_id,
+                        room_id=room_id,
+                        account_data_type="org.matrix.roorarraz",
+                        content={"roo": "rar"},
+                    )
                 )
-            )
+            elif extension_name == "receipts":
+                event_response = self.helper.send(
+                    room_id, body="new event", tok=user1_tok
+                )
+                # Read last event
+                channel = self.make_request(
+                    "POST",
+                    f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{event_response["event_id"]}",
+                    {},
+                    access_token=user1_tok,
+                )
+                self.assertEqual(channel.code, 200, channel.json_body)
+            else:
+                raise AssertionError(f"Unknown extension name: {extension_name}")
 
         main_sync_body = {
             "lists": {
@@ -4686,7 +4704,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         sync_body = {
             **main_sync_body,
             "extensions": {
-                "account_data": {
+                extension_name: {
                     "enabled": True,
                     "lists": ["foo-list", "non-existent-list"],
                     "rooms": [room_id1, room_id2, "!non-existent-room"],
@@ -4703,7 +4721,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.assertIncludes(
             {
                 room_id_to_human_name_map[room_id]
-                for room_id in response_body["extensions"]["account_data"]
+                for room_id in response_body["extensions"][extension_name]
                 .get("rooms")
                 .keys()
             },
@@ -4715,7 +4733,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         sync_body = {
             **main_sync_body,
             "extensions": {
-                "account_data": {
+                extension_name: {
                     "enabled": True,
                     # "lists": ["*"],
                     # "rooms": ["*"],
@@ -4732,7 +4750,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.assertIncludes(
             {
                 room_id_to_human_name_map[room_id]
-                for room_id in response_body["extensions"]["account_data"]
+                for room_id in response_body["extensions"][extension_name]
                 .get("rooms")
                 .keys()
             },
@@ -4744,7 +4762,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         sync_body = {
             **main_sync_body,
             "extensions": {
-                "account_data": {
+                extension_name: {
                     "enabled": True,
                     "lists": [],
                     "rooms": [],
@@ -4761,7 +4779,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.assertIncludes(
             {
                 room_id_to_human_name_map[room_id]
-                for room_id in response_body["extensions"]["account_data"]
+                for room_id in response_body["extensions"][extension_name]
                 .get("rooms")
                 .keys()
             },
@@ -4773,7 +4791,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         sync_body = {
             **main_sync_body,
             "extensions": {
-                "account_data": {
+                extension_name: {
                     "enabled": True,
                     "lists": ["*"],
                     "rooms": [],
@@ -4790,7 +4808,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.assertIncludes(
             {
                 room_id_to_human_name_map[room_id]
-                for room_id in response_body["extensions"]["account_data"]
+                for room_id in response_body["extensions"][extension_name]
                 .get("rooms")
                 .keys()
             },
@@ -4798,6 +4816,35 @@ class SlidingSyncTestCase(SlidingSyncBase):
             exact=True,
         )
 
+        # Try requesting a room that is only in a list
+        sync_body = {
+            **main_sync_body,
+            "extensions": {
+                extension_name: {
+                    "enabled": True,
+                    "lists": [],
+                    "rooms": [room_id5],
+                }
+            },
+        }
+        response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+        # room1: ❌ Not requested
+        # room2: ❌ Not requested
+        # room3: ❌ Not requested
+        # room4: ❌ Not requested
+        # room5: ✅ Requested via `rooms` and is in a list
+        self.assertIncludes(
+            {
+                room_id_to_human_name_map[room_id]
+                for room_id in response_body["extensions"][extension_name]
+                .get("rooms")
+                .keys()
+            },
+            {"room5"},
+            exact=True,
+        )
+
 
 class SlidingSyncToDeviceExtensionTestCase(SlidingSyncBase):
     """Tests for the to-device sliding sync extension"""
@@ -6156,6 +6203,7 @@ class SlidingSyncAccountDataExtensionTestCase(SlidingSyncBase):
             channel.json_body["extensions"]["account_data"].get("rooms")
         )
 
+
 class SlidingSyncReceiptsExtensionTestCase(unittest.HomeserverTestCase):
     """Tests for the receipts sliding sync extension"""