diff options
-rw-r--r-- | synapse/handlers/relations.py | 18 | ||||
-rw-r--r-- | synapse/rest/client/relations.py | 4 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 46 |
3 files changed, 68 insertions, 0 deletions
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index af83ac1a7b..8f17ee4290 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -485,6 +485,7 @@ class RelationsHandler: self, requester: Requester, room_id: str, + include: str, limit: int = 5, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, @@ -494,6 +495,8 @@ class RelationsHandler: Args: requester: The user requesting the relations. room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. limit: Only fetch the most recent `limit` events. from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. @@ -518,6 +521,21 @@ class RelationsHandler: events = await self._main_store.get_events_as_list(thread_roots) + if include == "participated": + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + events = await filter_events_for_client( self._storage_controllers, user_id, diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index faa962e3a8..8d1fd4fb98 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -113,6 +113,9 @@ class ThreadsServlet(RestServlet): limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") + include = parse_string( + request, "include", default="all", allowed_values=["all", "participated"] + ) # Return the relations from_token = None @@ -125,6 +128,7 @@ class ThreadsServlet(RestServlet): result = await self._relations_handler.get_threads( requester=requester, room_id=room_id, + include=include, limit=limit, from_token=from_token, to_token=to_token, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 0666bec479..6b302d90bf 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1750,4 +1750,50 @@ class ThreadsTestCase(BaseRelationsTestCase): self.assertNotIn("next_batch", channel.json_body, channel.json_body) + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_include(self) -> None: + """Filtering threads to all or participated in should work.""" + # Thread 1 has the user as the root event. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 has the user replying. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Thread 3 has the user not participating in. + res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token) + thread_3 = res["event_id"] + self._send_relation( + RelationTypes.THREAD, + "m.room.test", + access_token=self.user2_token, + parent_id=thread_3, + ) + + # All threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + thread_roots, [thread_3, thread_2, thread_1], channel.json_body + ) + + # Only participated threads. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) + # XXX Test ignoring users. |