summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/rest/client/test_relations.py111
-rw-r--r--tests/unittest.py7
2 files changed, 117 insertions, 1 deletions
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index b8a1b92a89..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
 # Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
         self.user_id, self.user_token = self._create_user("alice")
         self.user2_id, self.user2_token = self._create_user("bob")
 
@@ -765,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertEquals(channel.json_body["chunk"], [])
 
+    def test_unknown_relations(self):
+        """Unknown relations should be accepted."""
+        channel = self._send_relation("m.relation.test", "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEquals(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # When bundling the unknown relation is not included.
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+        # But unknown relations can be directly queried.
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(channel.json_body["chunk"], [])
+
     def _send_relation(
         self,
         relation_type: str,
@@ -811,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         access_token = self.login(localpart, "abc123")
 
         return user_id, access_token
+
+    def test_background_update(self):
+        """Test the event_arbitrary_relations background update."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_good = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_bad = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_event_id = channel.json_body["event_id"]
+
+        # Clean-up the table as if the inserts did not happen during event creation.
+        self.get_success(
+            self.store.db_pool.simple_delete_many(
+                table="event_relations",
+                column="event_id",
+                iterable=(annotation_event_id_bad, thread_event_id),
+                keyvalues={},
+                desc="RelationsTestCase.test_background_update",
+            )
+        )
+
+        # Only the "good" annotation should be found.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good],
+        )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+        self.wait_for_background_updates()
+
+        # The "good" annotation and the thread should be found, but not the "bad"
+        # annotation.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertCountEqual(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good, thread_event_id],
+        )
diff --git a/tests/unittest.py b/tests/unittest.py
index c9a08a3420..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,7 +331,12 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """Block until all background database updates have completed."""
+        """
+        Block until all background database updates have completed.
+
+        Note that callers must ensure that's a store property created on the
+        testcase.
+        """
         while not self.get_success(
             self.store.db_pool.updates.has_completed_background_updates()
         ):