summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11161.feature1
-rw-r--r--synapse/handlers/message.py54
-rw-r--r--synapse/storage/databases/main/relations.py67
-rw-r--r--tests/rest/client/test_relations.py62
4 files changed, 176 insertions, 8 deletions
diff --git a/changelog.d/11161.feature b/changelog.d/11161.feature
new file mode 100644
index 0000000000..76b0d28084
--- /dev/null
+++ b/changelog.d/11161.feature
@@ -0,0 +1 @@
+Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d4c2a6ab7a..22dd4cf5fd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1001,13 +1001,52 @@ class EventCreationHandler:
             )
 
         self.validator.validate_new(event, self.config)
+        await self._validate_event_relation(event)
+        logger.debug("Created event %s", event.event_id)
+
+        return event, context
+
+    async def _validate_event_relation(self, event: EventBase) -> None:
+        """
+        Ensure the relation data on a new event is not bogus.
+
+        Args:
+            event: The event being created.
+
+        Raises:
+            SynapseError if the event is invalid.
+        """
+
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            return
+
+        relation_type = relation.get("rel_type")
+        if not relation_type:
+            return
+
+        # Ensure the parent is real.
+        relates_to = relation.get("event_id")
+        if not relates_to:
+            return
+
+        parent_event = await self.store.get_event(relates_to, allow_none=True)
+        if parent_event:
+            # And in the same room.
+            if parent_event.room_id != event.room_id:
+                raise SynapseError(400, "Relations must be in the same room")
+
+        else:
+            # There must be some reason that the client knows the event exists,
+            # see if there are existing relations. If so, assume everything is fine.
+            if not await self.store.event_is_target_of_relation(relates_to):
+                # Otherwise, the client can't know about the parent event!
+                raise SynapseError(400, "Can't send relation to unknown event")
 
         # If this event is an annotation then we check that that the sender
         # can't annotate the same way twice (e.g. stops users from liking an
         # event multiple times).
-        relation = event.content.get("m.relates_to", {})
-        if relation.get("rel_type") == RelationTypes.ANNOTATION:
-            relates_to = relation["event_id"]
+        if relation_type == RelationTypes.ANNOTATION:
             aggregation_key = relation["key"]
 
             already_exists = await self.store.has_user_annotated_event(
@@ -1016,9 +1055,12 @@ class EventCreationHandler:
             if already_exists:
                 raise SynapseError(400, "Can't send same reaction twice")
 
-        logger.debug("Created event %s", event.event_id)
-
-        return event, context
+        # Don't attempt to start a thread if the parent event is a relation.
+        elif relation_type == RelationTypes.THREAD:
+            if await self.store.event_includes_relation(relates_to):
+                raise SynapseError(
+                    400, "Cannot start threads from an event with a relation"
+                )
 
     @measure_func("handle_new_client_event")
     async def handle_new_client_event(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
+    async def event_includes_relation(self, event_id: str) -> bool:
+        """Check if the given event relates to another event.
+
+        An event has a relation if it has a valid m.relates_to with a rel_type
+        and event_id in the content:
+
+        {
+            "content": {
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": "$other_event_id"
+                }
+            }
+        }
+
+        Args:
+            event_id: The event to check.
+
+        Returns:
+            True if the event includes a valid relation.
+        """
+
+        result = await self.db_pool.simple_select_one_onecol(
+            table="event_relations",
+            keyvalues={"event_id": event_id},
+            retcol="event_id",
+            allow_none=True,
+            desc="event_includes_relation",
+        )
+        return result is not None
+
+    async def event_is_target_of_relation(self, parent_id: str) -> bool:
+        """Check if the given event is the target of another event's relation.
+
+        An event is the target of an event relation if it has a valid
+        m.relates_to with a rel_type and event_id pointing to parent_id in the
+        content:
+
+        {
+            "content": {
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": "$parent_id"
+                }
+            }
+        }
+
+        Args:
+            parent_id: The event to check.
+
+        Returns:
+            True if the event is the target of another event's relation.
+        """
+
+        result = await self.db_pool.simple_select_one_onecol(
+            table="event_relations",
+            keyvalues={"relates_to_id": parent_id},
+            retcol="event_id",
+            allow_none=True,
+            desc="event_is_target_of_relation",
+        )
+        return result is not None
+
     @cached(tree=True)
     async def get_aggregation_groups_for_event(
         self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 %s;
         """
 
-        def _get_if_event_has_relations(txn) -> List[str]:
+        def _get_if_events_have_relations(txn) -> List[str]:
             clauses: List[str] = []
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
             return [row[0] for row in txn]
 
         return await self.db_pool.runInteraction(
-            "get_if_event_has_relations", _get_if_event_has_relations
+            "get_if_events_have_relations", _get_if_events_have_relations
         )
 
     async def has_user_annotated_event(
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 78c2fb86b9..b8a1b92a89 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -91,6 +91,49 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_deny_invalid_event(self):
+        """Test that we deny relations on non-existant events"""
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION,
+            EventTypes.Message,
+            parent_id="foo",
+            content={"body": "foo", "msgtype": "m.text"},
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
+        # Unless that event is referenced from another event!
+        self.get_success(
+            self.hs.get_datastore().db_pool.simple_insert(
+                table="event_relations",
+                values={
+                    "event_id": "bar",
+                    "relates_to_id": "foo",
+                    "relation_type": RelationTypes.THREAD,
+                },
+                desc="test_deny_invalid_event",
+            )
+        )
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            parent_id="foo",
+            content={"body": "foo", "msgtype": "m.text"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+    def test_deny_invalid_room(self):
+        """Test that we deny relations on non-existant events"""
+        # Create another room and send a message in it.
+        room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+        res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+        parent_id = res["event_id"]
+
+        # Attempt to send an annotation to that event.
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A"
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
     def test_deny_double_react(self):
         """Test that we deny relations on membership events"""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
@@ -99,6 +142,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(400, channel.code, channel.json_body)
 
+    def test_deny_forked_thread(self):
+        """It is invalid to start a thread off a thread."""
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo"},
+            parent_id=self.parent_id,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        parent_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo"},
+            parent_id=parent_id,
+        )
+        self.assertEquals(400, channel.code, channel.json_body)
+
     def test_basic_paginate_relations(self):
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")