summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11391.feature1
-rw-r--r--synapse/storage/databases/main/events.py29
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py88
-rw-r--r--synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql (renamed from synapse/storage/schema/main/delta/65/02_thread_relations.sql)2
-rw-r--r--tests/rest/client/test_relations.py111
-rw-r--r--tests/unittest.py7
6 files changed, 193 insertions, 45 deletions
diff --git a/changelog.d/11391.feature b/changelog.d/11391.feature
new file mode 100644
index 0000000000..4f696285a7
--- /dev/null
+++ b/changelog.d/11391.feature
@@ -0,0 +1 @@
+Store and allow querying of arbitrary event relations.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 120e4807d1..06832221ad 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1,6 +1,6 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1696,34 +1696,33 @@ class PersistEventsStore:
                     },
                 )
 
-    def _handle_event_relations(self, txn, event):
-        """Handles inserting relation data during peristence of events
+    def _handle_event_relations(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
+        """Handles inserting relation data during persistence of events
 
         Args:
-            txn
-            event (EventBase)
+            txn: The current database transaction.
+            event: The event which might have relations.
         """
         relation = event.content.get("m.relates_to")
         if not relation:
             # No relations
             return
 
+        # Relations must have a type and parent event ID.
         rel_type = relation.get("rel_type")
-        if rel_type not in (
-            RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCE,
-            RelationTypes.REPLACE,
-            RelationTypes.THREAD,
-        ):
-            # Unknown relation type
+        if not isinstance(rel_type, str):
             return
 
         parent_id = relation.get("event_id")
-        if not parent_id:
-            # Invalid relation
+        if not isinstance(parent_id, str):
             return
 
-        aggregation_key = relation.get("key")
+        # Annotations have a key field.
+        aggregation_key = None
+        if rel_type == RelationTypes.ANNOTATION:
+            aggregation_key = relation.get("key")
 
         self.db_pool.simple_insert_txn(
             txn,
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index ae3a8a63e4..c88fd35e7f 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1,4 +1,4 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -171,8 +171,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self._purged_chain_cover_index,
         )
 
+        # The event_thread_relation background update was replaced with the
+        # event_arbitrary_relations one, which handles any relation to avoid
+        # needed to potentially crawl the entire events table in the future.
+        self.db_pool.updates.register_noop_background_update("event_thread_relation")
+
         self.db_pool.updates.register_background_update_handler(
-            "event_thread_relation", self._event_thread_relation
+            "event_arbitrary_relations",
+            self._event_arbitrary_relations,
         )
 
         ################################################################################
@@ -1099,23 +1105,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
-    async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
-        """Background update handler which will store thread relations for existing events."""
+    async def _event_arbitrary_relations(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update handler which will store previously unknown relations for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
-        def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+        def _event_arbitrary_relations_txn(txn: LoggingTransaction) -> int:
+            # Fetch events and then filter based on whether the event has a
+            # relation or not.
             txn.execute(
                 """
                 SELECT event_id, json FROM event_json
-                LEFT JOIN event_relations USING (event_id)
-                WHERE event_id > ? AND event_relations.event_id IS NULL
+                WHERE event_id > ?
                 ORDER BY event_id LIMIT ?
                 """,
                 (last_event_id, batch_size),
             )
 
             results = list(txn)
-            missing_thread_relations = []
+            # (event_id, parent_id, rel_type) for each relation
+            relations_to_insert: List[Tuple[str, str, str]] = []
             for (event_id, event_json_raw) in results:
                 try:
                     event_json = db_to_json(event_json_raw)
@@ -1127,48 +1137,70 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                     )
                     continue
 
-                # If there's no relation (or it is not a thread), skip!
+                # If there's no relation, skip!
                 relates_to = event_json["content"].get("m.relates_to")
                 if not relates_to or not isinstance(relates_to, dict):
                     continue
-                if relates_to.get("rel_type") != RelationTypes.THREAD:
+
+                # If the relation type or parent event ID is not a string, skip it.
+                #
+                # Do not consider relation types that have existed for a long time,
+                # since they will already be listed in the `event_relations` table.
+                rel_type = relates_to.get("rel_type")
+                if not isinstance(rel_type, str) or rel_type in (
+                    RelationTypes.ANNOTATION,
+                    RelationTypes.REFERENCE,
+                    RelationTypes.REPLACE,
+                ):
                     continue
 
-                # Get the parent ID.
                 parent_id = relates_to.get("event_id")
                 if not isinstance(parent_id, str):
                     continue
 
-                missing_thread_relations.append((event_id, parent_id))
+                relations_to_insert.append((event_id, parent_id, rel_type))
+
+            # Insert the missing data, note that we upsert here in case the event
+            # has already been processed.
+            if relations_to_insert:
+                self.db_pool.simple_upsert_many_txn(
+                    txn=txn,
+                    table="event_relations",
+                    key_names=("event_id",),
+                    key_values=[(r[0],) for r in relations_to_insert],
+                    value_names=("relates_to_id", "relation_type"),
+                    value_values=[r[1:] for r in relations_to_insert],
+                )
 
-            # Insert the missing data.
-            self.db_pool.simple_insert_many_txn(
-                txn=txn,
-                table="event_relations",
-                values=[
-                    {
-                        "event_id": event_id,
-                        "relates_to_Id": parent_id,
-                        "relation_type": RelationTypes.THREAD,
-                    }
-                    for event_id, parent_id in missing_thread_relations
-                ],
-            )
+                # Iterate the parent IDs and invalidate caches.
+                for parent_id in {r[1] for r in relations_to_insert}:
+                    cache_tuple = (parent_id,)
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_relations_for_event, cache_tuple
+                    )
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_aggregation_groups_for_event, cache_tuple
+                    )
+                    self._invalidate_cache_and_stream(
+                        txn, self.get_thread_summary, cache_tuple
+                    )
 
             if results:
                 latest_event_id = results[-1][0]
                 self.db_pool.updates._background_update_progress_txn(
-                    txn, "event_thread_relation", {"last_event_id": latest_event_id}
+                    txn, "event_arbitrary_relations", {"last_event_id": latest_event_id}
                 )
 
             return len(results)
 
         num_rows = await self.db_pool.runInteraction(
-            desc="event_thread_relation", func=_event_thread_relation_txn
+            desc="event_arbitrary_relations", func=_event_arbitrary_relations_txn
         )
 
         if not num_rows:
-            await self.db_pool.updates._end_background_update("event_thread_relation")
+            await self.db_pool.updates._end_background_update(
+                "event_arbitrary_relations"
+            )
 
         return num_rows
 
diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
index d60517f7b4..267b2cb539 100644
--- a/synapse/storage/schema/main/delta/65/02_thread_relations.sql
+++ b/synapse/storage/schema/main/delta/65/07_arbitrary_relations.sql
@@ -15,4 +15,4 @@
 
 -- Check old events for thread relations.
 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-  (6502, 'event_thread_relation', '{}');
+  (6507, 'event_arbitrary_relations', '{}');
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index b8a1b92a89..eb10d43217 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1,4 +1,5 @@
 # Copyright 2019 New Vector Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         return config
 
     def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
         self.user_id, self.user_token = self._create_user("alice")
         self.user2_id, self.user2_token = self._create_user("bob")
 
@@ -765,6 +768,52 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertIn("chunk", channel.json_body)
         self.assertEquals(channel.json_body["chunk"], [])
 
+    def test_unknown_relations(self):
+        """Unknown relations should be accepted."""
+        channel = self._send_relation("m.relation.test", "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {"event_id": event_id, "sender": self.user_id, "type": "m.room.test"},
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEquals(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # When bundling the unknown relation is not included.
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
+        # But unknown relations can be directly queried.
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(channel.json_body["chunk"], [])
+
     def _send_relation(
         self,
         relation_type: str,
@@ -811,3 +860,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         access_token = self.login(localpart, "abc123")
 
         return user_id, access_token
+
+    def test_background_update(self):
+        """Test the event_arbitrary_relations background update."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_good = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_event_id_bad = channel.json_body["event_id"]
+
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_event_id = channel.json_body["event_id"]
+
+        # Clean-up the table as if the inserts did not happen during event creation.
+        self.get_success(
+            self.store.db_pool.simple_delete_many(
+                table="event_relations",
+                column="event_id",
+                iterable=(annotation_event_id_bad, thread_event_id),
+                keyvalues={},
+                desc="RelationsTestCase.test_background_update",
+            )
+        )
+
+        # Only the "good" annotation should be found.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertEquals(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good],
+        )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "event_arbitrary_relations", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+        self.wait_for_background_updates()
+
+        # The "good" annotation and the thread should be found, but not the "bad"
+        # annotation.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=10",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        self.assertCountEqual(
+            [ev["event_id"] for ev in channel.json_body["chunk"]],
+            [annotation_event_id_good, thread_event_id],
+        )
diff --git a/tests/unittest.py b/tests/unittest.py
index c9a08a3420..165aafc574 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -331,7 +331,12 @@ class HomeserverTestCase(TestCase):
             time.sleep(0.01)
 
     def wait_for_background_updates(self) -> None:
-        """Block until all background database updates have completed."""
+        """
+        Block until all background database updates have completed.
+
+        Note that callers must ensure that's a store property created on the
+        testcase.
+        """
         while not self.get_success(
             self.store.db_pool.updates.has_completed_background_updates()
         ):