summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-10-17 11:32:11 -0400
committerGitHub <noreply@github.com>2022-10-17 11:32:11 -0400
commit4283bd1cf9c3da2157c3642a7c4f105e9fac2636 (patch)
treef13dd6d01fafa2170d808f2060dee6a9ab5fca56 /tests/storage
parentBump psycopg2 from 2.9.3 to 2.9.4 (#14200) (diff)
downloadsynapse-4283bd1cf9c3da2157c3642a7c4f105e9fac2636.tar.xz
Support filtering the /messages API by relation type (MSC3874). (#14148)
Gated behind an experimental configuration flag.
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_stream.py118
1 files changed, 85 insertions, 33 deletions
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 78663a53fe..34fa810cf6 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -16,7 +16,6 @@ from typing import List
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.filtering import Filter
-from synapse.events import EventBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.types import JsonDict
@@ -40,7 +39,7 @@ class PaginationTestCase(HomeserverTestCase):
 
     def default_config(self):
         config = super().default_config()
-        config["experimental_features"] = {"msc3440_enabled": True}
+        config["experimental_features"] = {"msc3874_enabled": True}
         return config
 
     def prepare(self, reactor, clock, homeserver):
@@ -58,6 +57,11 @@ class PaginationTestCase(HomeserverTestCase):
         self.third_tok = self.login("third", "test")
         self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
 
+        # Store a token which is after all the room creation events.
+        self.from_token = self.get_success(
+            self.hs.get_event_sources().get_current_token_for_pagination(self.room_id)
+        )
+
         # An initial event with a relation from second user.
         res = self.helper.send_event(
             room_id=self.room_id,
@@ -66,7 +70,7 @@ class PaginationTestCase(HomeserverTestCase):
             tok=self.tok,
         )
         self.event_id_1 = res["event_id"]
-        self.helper.send_event(
+        res = self.helper.send_event(
             room_id=self.room_id,
             type="m.reaction",
             content={
@@ -78,6 +82,7 @@ class PaginationTestCase(HomeserverTestCase):
             },
             tok=self.second_tok,
         )
+        self.event_id_annotation = res["event_id"]
 
         # Another event with a relation from third user.
         res = self.helper.send_event(
@@ -87,7 +92,7 @@ class PaginationTestCase(HomeserverTestCase):
             tok=self.tok,
         )
         self.event_id_2 = res["event_id"]
-        self.helper.send_event(
+        res = self.helper.send_event(
             room_id=self.room_id,
             type="m.reaction",
             content={
@@ -98,68 +103,59 @@ class PaginationTestCase(HomeserverTestCase):
             },
             tok=self.third_tok,
         )
+        self.event_id_reference = res["event_id"]
 
         # An event with no relations.
-        self.helper.send_event(
+        res = self.helper.send_event(
             room_id=self.room_id,
             type=EventTypes.Message,
             content={"msgtype": "m.text", "body": "No relations"},
             tok=self.tok,
         )
+        self.event_id_none = res["event_id"]
 
-    def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
+    def _filter_messages(self, filter: JsonDict) -> List[str]:
         """Make a request to /messages with a filter, returns the chunk of events."""
 
-        from_token = self.get_success(
-            self.hs.get_event_sources().get_current_token_for_pagination(self.room_id)
-        )
-
         events, next_key = self.get_success(
             self.hs.get_datastores().main.paginate_room_events(
                 room_id=self.room_id,
-                from_key=from_token.room_key,
+                from_key=self.from_token.room_key,
                 to_key=None,
-                direction="b",
+                direction="f",
                 limit=10,
                 event_filter=Filter(self.hs, filter),
             )
         )
 
-        return events
+        return [ev.event_id for ev in events]
 
     def test_filter_relation_senders(self):
         # Messages which second user reacted to.
         filter = {"related_by_senders": [self.second_user_id]}
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 1, chunk)
-        self.assertEqual(chunk[0].event_id, self.event_id_1)
+        self.assertEqual(chunk, [self.event_id_1])
 
         # Messages which third user reacted to.
         filter = {"related_by_senders": [self.third_user_id]}
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 1, chunk)
-        self.assertEqual(chunk[0].event_id, self.event_id_2)
+        self.assertEqual(chunk, [self.event_id_2])
 
         # Messages which either user reacted to.
         filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 2, chunk)
-        self.assertCountEqual(
-            [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
-        )
+        self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
 
     def test_filter_relation_type(self):
         # Messages which have annotations.
         filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 1, chunk)
-        self.assertEqual(chunk[0].event_id, self.event_id_1)
+        self.assertEqual(chunk, [self.event_id_1])
 
         # Messages which have references.
         filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 1, chunk)
-        self.assertEqual(chunk[0].event_id, self.event_id_2)
+        self.assertEqual(chunk, [self.event_id_2])
 
         # Messages which have either annotations or references.
         filter = {
@@ -169,10 +165,7 @@ class PaginationTestCase(HomeserverTestCase):
             ]
         }
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 2, chunk)
-        self.assertCountEqual(
-            [c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
-        )
+        self.assertCountEqual(chunk, [self.event_id_1, self.event_id_2])
 
     def test_filter_relation_senders_and_type(self):
         # Messages which second user reacted to.
@@ -181,8 +174,7 @@ class PaginationTestCase(HomeserverTestCase):
             "related_by_rel_types": [RelationTypes.ANNOTATION],
         }
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 1, chunk)
-        self.assertEqual(chunk[0].event_id, self.event_id_1)
+        self.assertEqual(chunk, [self.event_id_1])
 
     def test_duplicate_relation(self):
         """An event should only be returned once if there are multiple relations to it."""
@@ -201,5 +193,65 @@ class PaginationTestCase(HomeserverTestCase):
 
         filter = {"related_by_senders": [self.second_user_id]}
         chunk = self._filter_messages(filter)
-        self.assertEqual(len(chunk), 1, chunk)
-        self.assertEqual(chunk[0].event_id, self.event_id_1)
+        self.assertEqual(chunk, [self.event_id_1])
+
+    def test_filter_rel_types(self) -> None:
+        # Messages which are annotations.
+        filter = {"org.matrix.msc3874.rel_types": [RelationTypes.ANNOTATION]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(chunk, [self.event_id_annotation])
+
+        # Messages which are references.
+        filter = {"org.matrix.msc3874.rel_types": [RelationTypes.REFERENCE]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(chunk, [self.event_id_reference])
+
+        # Messages which are either annotations or references.
+        filter = {
+            "org.matrix.msc3874.rel_types": [
+                RelationTypes.ANNOTATION,
+                RelationTypes.REFERENCE,
+            ]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertCountEqual(
+            chunk,
+            [self.event_id_annotation, self.event_id_reference],
+        )
+
+    def test_filter_not_rel_types(self) -> None:
+        # Messages which are not annotations.
+        filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.ANNOTATION]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(
+            chunk,
+            [
+                self.event_id_1,
+                self.event_id_2,
+                self.event_id_reference,
+                self.event_id_none,
+            ],
+        )
+
+        # Messages which are not references.
+        filter = {"org.matrix.msc3874.not_rel_types": [RelationTypes.REFERENCE]}
+        chunk = self._filter_messages(filter)
+        self.assertEqual(
+            chunk,
+            [
+                self.event_id_1,
+                self.event_id_annotation,
+                self.event_id_2,
+                self.event_id_none,
+            ],
+        )
+
+        # Messages which are neither annotations or references.
+        filter = {
+            "org.matrix.msc3874.not_rel_types": [
+                RelationTypes.ANNOTATION,
+                RelationTypes.REFERENCE,
+            ]
+        }
+        chunk = self._filter_messages(filter)
+        self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none])