summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-10-21 14:39:16 -0400
committerGitHub <noreply@github.com>2021-10-21 14:39:16 -0400
commitba00e20234eadae66f105f8bda64e39beed9a92d (patch)
treea828912312bbc6ecbdf9d3a5d2bfe27b7296fb3f
parentFix adding excluded users to the private room sharing tables when joining a r... (diff)
downloadsynapse-ba00e20234eadae66f105f8bda64e39beed9a92d.tar.xz
Add a thread relation type per MSC3440. (#11088)
Adds experimental support for MSC3440's `io.element.thread` relation
type (and the aggregation for it).
Diffstat (limited to '')
-rw-r--r--changelog.d/11088.feature1
-rw-r--r--synapse/api/constants.py1
-rw-r--r--synapse/config/experimental.py2
-rw-r--r--synapse/events/utils.py17
-rw-r--r--synapse/rest/client/relations.py3
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/relations.py59
-rw-r--r--tests/rest/client/test_relations.py40
8 files changed, 119 insertions, 8 deletions
diff --git a/changelog.d/11088.feature b/changelog.d/11088.feature
new file mode 100644
index 0000000000..76b0d28084
--- /dev/null
+++ b/changelog.d/11088.feature
@@ -0,0 +1 @@
+Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index a31f037748..a33ac34161 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -176,6 +176,7 @@ class RelationTypes:
     ANNOTATION = "m.annotation"
     REPLACE = "m.replace"
     REFERENCE = "m.reference"
+    THREAD = "io.element.thread"
 
 
 class LimitBlockingTypes:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b013a3918c..8b098ad48d 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -26,6 +26,8 @@ class ExperimentalConfig(Config):
 
         # Whether to enable experimental MSC1849 (aka relations) support
         self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True)
+        # MSC3440 (thread relation)
+        self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
 
         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 3f3eba86a8..6fa631aa1d 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -386,6 +386,7 @@ class EventClientSerializer:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
 
     async def serialize_event(
         self,
@@ -462,6 +463,22 @@ class EventClientSerializer:
                     "sender": edit.sender,
                 }
 
+            # If this event is the start of a thread, include a summary of the replies.
+            if self._msc3440_enabled:
+                (
+                    thread_count,
+                    latest_thread_event,
+                ) = await self.store.get_thread_summary(event_id)
+                if latest_thread_event:
+                    r = serialized_event["unsigned"].setdefault("m.relations", {})
+                    r[RelationTypes.THREAD] = {
+                        # Don't bundle aggregations as this could recurse forever.
+                        "latest_event": await self.serialize_event(
+                            latest_thread_event, time_now, bundle_aggregations=False
+                        ),
+                        "count": thread_count,
+                    }
+
         return serialized_event
 
     async def serialize_events(
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d695c18be2..58f6699073 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -128,9 +128,10 @@ class RelationSendServlet(RestServlet):
 
         content["m.relates_to"] = {
             "event_id": parent_id,
-            "key": aggregation_key,
             "rel_type": relation_type,
         }
+        if aggregation_key is not None:
+            content["m.relates_to"]["key"] = aggregation_key
 
         event_dict = {
             "type": event_type,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 37439f8562..8d9086ecf0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1710,6 +1710,7 @@ class PersistEventsStore:
             RelationTypes.ANNOTATION,
             RelationTypes.REFERENCE,
             RelationTypes.REPLACE,
+            RelationTypes.THREAD,
         ):
             # Unknown relation type
             return
@@ -1740,6 +1741,9 @@ class PersistEventsStore:
         if rel_type == RelationTypes.REPLACE:
             txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
+        if rel_type == RelationTypes.THREAD:
+            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2bbf6d6a95..40760fbd1b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import Optional, Tuple
 
 import attr
 
@@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return await self.get_event(edit_id, allow_none=True)
 
+    @cached()
+    async def get_thread_summary(
+        self, event_id: str
+    ) -> Tuple[int, Optional[EventBase]]:
+        """Get the number of threaded replies, the senders of those replies, and
+        the latest reply (if any) for the given event.
+
+        Args:
+            event_id: The original event ID
+
+        Returns:
+            The number of items in the thread and the most recent response, if any.
+        """
+
+        def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
+            # Fetch the count of threaded events and the latest event ID.
+            # TODO Should this only allow m.room.message events.
+            sql = """
+                SELECT event_id
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE
+                    relates_to_id = ?
+                    AND relation_type = ?
+                ORDER BY topological_ordering DESC, stream_ordering DESC
+                LIMIT 1
+            """
+
+            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            row = txn.fetchone()
+            if row is None:
+                return 0, None
+
+            latest_event_id = row[0]
+
+            sql = """
+                SELECT COALESCE(COUNT(event_id), 0)
+                FROM event_relations
+                WHERE
+                    relates_to_id = ?
+                    AND relation_type = ?
+            """
+            txn.execute(sql, (event_id, RelationTypes.THREAD))
+            count = txn.fetchone()[0]
+
+            return count, latest_event_id
+
+        count, latest_event_id = await self.db_pool.runInteraction(
+            "get_thread_summary", _get_thread_summary_txn
+        )
+
+        latest_event = None
+        if latest_event_id:
+            latest_event = await self.get_event(latest_event_id, allow_none=True)
+
+        return count, latest_event
+
     async def has_user_annotated_event(
         self, parent_id: str, event_type: str, aggregation_key: str, sender: str
     ) -> bool:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 3c7d49f0b4..78c2fb86b9 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -101,10 +101,10 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def test_basic_paginate_relations(self):
         """Tests that calling pagination API correctly the latest relations."""
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
         self.assertEquals(200, channel.code, channel.json_body)
         annotation_id = channel.json_body["event_id"]
 
@@ -141,8 +141,10 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         """
 
         expected_event_ids = []
-        for _ in range(10):
-            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        for idx in range(10):
+            channel = self._send_relation(
+                RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx)
+            )
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
@@ -386,8 +388,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(400, channel.code, channel.json_body)
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_aggregation_get_event(self):
-        """Test that annotations and references get correctly bundled when
+        """Test that annotations, references, and threads get correctly bundled when
         getting the parent event.
         """
 
@@ -410,6 +413,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         reply_2 = channel.json_body["event_id"]
 
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        self.assertEquals(200, channel.code, channel.json_body)
+        thread_2 = channel.json_body["event_id"]
+
         channel = self.make_request(
             "GET",
             "/rooms/%s/event/%s" % (self.room, self.parent_id),
@@ -429,6 +439,25 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                 RelationTypes.REFERENCE: {
                     "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
                 },
+                RelationTypes.THREAD: {
+                    "count": 2,
+                    "latest_event": {
+                        "age": 100,
+                        "content": {
+                            "m.relates_to": {
+                                "event_id": self.parent_id,
+                                "rel_type": RelationTypes.THREAD,
+                            }
+                        },
+                        "event_id": thread_2,
+                        "origin_server_ts": 1600,
+                        "room_id": self.room,
+                        "sender": self.user_id,
+                        "type": "m.room.test",
+                        "unsigned": {"age": 100},
+                        "user_id": self.user_id,
+                    },
+                },
             },
         )
 
@@ -559,7 +588,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {
                 "m.relates_to": {
                     "event_id": self.parent_id,
-                    "key": None,
                     "rel_type": "m.reference",
                 }
             },