summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12766.bugfix1
-rw-r--r--synapse/handlers/relations.py58
-rw-r--r--tests/rest/client/test_relations.py85
3 files changed, 97 insertions, 47 deletions
diff --git a/changelog.d/12766.bugfix b/changelog.d/12766.bugfix
new file mode 100644
index 0000000000..912c3deb70
--- /dev/null
+++ b/changelog.d/12766.bugfix
@@ -0,0 +1 @@
+Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 9a1cc11bb3..0b63cd2186 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,16 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    FrozenSet,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
 import attr
 
@@ -256,13 +247,19 @@ class RelationsHandler:
 
         return filtered_results
 
-    async def get_threads_for_events(
-        self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
+    async def _get_threads_for_events(
+        self,
+        events_by_id: Dict[str, EventBase],
+        relations_by_id: Dict[str, str],
+        user_id: str,
+        ignored_users: FrozenSet[str],
     ) -> Dict[str, _ThreadAggregation]:
         """Get the bundled aggregations for threads for the requested events.
 
         Args:
-            event_ids: Events to get aggregations for threads.
+            events_by_id: A map of event_id to events to get aggregations for threads.
+            relations_by_id: A map of event_id to the relation type, if one exists
+                for that event.
             user_id: The user requesting the bundled aggregations.
             ignored_users: The users ignored by the requesting user.
 
@@ -273,16 +270,34 @@ class RelationsHandler:
         """
         user = UserID.from_string(user_id)
 
+        # It is not valid to start a thread on an event which itself relates to another event.
+        event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
+
         # Fetch thread summaries.
         summaries = await self._main_store.get_thread_summaries(event_ids)
 
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
+        # Limit fetching whether the requester has participated in a thread to
+        # events which are thread roots.
         thread_event_ids = [
             event_id for event_id, summary in summaries.items() if summary
         ]
-        participated = await self._main_store.get_threads_participated(
-            thread_event_ids, user_id
+
+        # Pre-seed thread participation with whether the requester sent the event.
+        participated = {
+            event_id: events_by_id[event_id].sender == user_id
+            for event_id in thread_event_ids
+        }
+        # For events the requester did not send, check the database for whether
+        # the requester sent a threaded reply.
+        participated.update(
+            await self._main_store.get_threads_participated(
+                [
+                    event_id
+                    for event_id in thread_event_ids
+                    if not participated[event_id]
+                ],
+                user_id,
+            )
         )
 
         # Then subtract off the results for any ignored users.
@@ -343,7 +358,8 @@ class RelationsHandler:
                     count=thread_count,
                     # If there's a thread summary it must also exist in the
                     # participated dictionary.
-                    current_user_participated=participated[event_id],
+                    current_user_participated=events_by_id[event_id].sender == user_id
+                    or participated[event_id],
                 )
 
         return results
@@ -401,9 +417,9 @@ class RelationsHandler:
         # events to be fetched. Thus, we check those first!
 
         # Fetch thread summaries (but only for the directly requested events).
-        threads = await self.get_threads_for_events(
-            # It is not valid to start a thread on an event which itself relates to another event.
-            [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+        threads = await self._get_threads_for_events(
+            events_by_id,
+            relations_by_id,
             user_id,
             ignored_users,
         )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index bc9cc51b92..62e4db23ef 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         relation_type: str,
         assertion_callable: Callable[[JsonDict], None],
         expected_db_txn_for_event: int,
+        access_token: Optional[str] = None,
     ) -> None:
         """
         Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 for relation-specific assertions.
             expected_db_txn_for_event: The number of database transactions which
                 are expected for a call to /event/.
+            access_token: The access token to user, defaults to self.user_token.
         """
+        access_token = access_token or self.user_token
 
         def assert_bundle(event_json: JsonDict) -> None:
             """Assert the expected values of the bundled aggregations."""
@@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/event/{self.parent_id}",
-            access_token=self.user_token,
+            access_token=access_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(channel.json_body)
@@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/messages?dir=b",
-            access_token=self.user_token,
+            access_token=access_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
@@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         channel = self.make_request(
             "GET",
             f"/rooms/{self.room}/context/{self.parent_id}",
-            access_token=self.user_token,
+            access_token=access_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
         assert_bundle(channel.json_body["event"])
@@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         # Request sync.
         filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
         channel = self.make_request(
-            "GET", f"/sync?filter={filter}", access_token=self.user_token
+            "GET", f"/sync?filter={filter}", access_token=access_token
         )
         self.assertEqual(200, channel.code, channel.json_body)
         room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
@@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             "/search",
             # Search term matches the parent message.
             content={"search_categories": {"room_events": {"search_term": "Hi"}}},
-            access_token=self.user_token,
+            access_token=access_token,
         )
         self.assertEqual(200, channel.code, channel.json_body)
         chunk = [
@@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         """
         Test that threads get correctly bundled.
         """
-        self._send_relation(RelationTypes.THREAD, "m.room.test")
-        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        # The root message is from "user", send replies as "user2".
+        self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
+        channel = self._send_relation(
+            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+        )
         thread_2 = channel.json_body["event_id"]
 
-        def assert_thread(bundled_aggregations: JsonDict) -> None:
-            self.assertEqual(2, bundled_aggregations.get("count"))
-            self.assertTrue(bundled_aggregations.get("current_user_participated"))
-            # The latest thread event has some fields that don't matter.
-            self.assert_dict(
-                {
-                    "content": {
-                        "m.relates_to": {
-                            "event_id": self.parent_id,
-                            "rel_type": RelationTypes.THREAD,
-                        }
+        # This needs two assertion functions which are identical except for whether
+        # the current_user_participated flag is True, create a factory for the
+        # two versions.
+        def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
+            def assert_thread(bundled_aggregations: JsonDict) -> None:
+                self.assertEqual(2, bundled_aggregations.get("count"))
+                self.assertEqual(
+                    participated, bundled_aggregations.get("current_user_participated")
+                )
+                # The latest thread event has some fields that don't matter.
+                self.assert_dict(
+                    {
+                        "content": {
+                            "m.relates_to": {
+                                "event_id": self.parent_id,
+                                "rel_type": RelationTypes.THREAD,
+                            }
+                        },
+                        "event_id": thread_2,
+                        "sender": self.user2_id,
+                        "type": "m.room.test",
                     },
-                    "event_id": thread_2,
-                    "sender": self.user_id,
-                    "type": "m.room.test",
-                },
-                bundled_aggregations.get("latest_event"),
-            )
+                    bundled_aggregations.get("latest_event"),
+                )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+            return assert_thread
+
+        # The "user" sent the root event and is making queries for the bundled
+        # aggregations: they have participated.
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+        # The "user2" sent replies in the thread and is making queries for the
+        # bundled aggregations: they have participated.
+        #
+        # Note that this re-uses some cached values, so the total number of
+        # queries is much smaller.
+        self._test_bundled_aggregations(
+            RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+        )
+
+        # A user with no interactions with the thread: they have not participated.
+        user3_id, user3_token = self._create_user("charlie")
+        self.helper.join(self.room, user=user3_id, tok=user3_token)
+        self._test_bundled_aggregations(
+            RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+        )
 
     def test_thread_with_bundled_aggregations_for_latest(self) -> None:
         """
@@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
 
     def test_nested_thread(self) -> None:
         """