diff options
-rw-r--r-- | changelog.d/12766.bugfix | 1 | ||||
-rw-r--r-- | synapse/handlers/relations.py | 58 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 85 |
3 files changed, 97 insertions, 47 deletions
diff --git a/changelog.d/12766.bugfix b/changelog.d/12766.bugfix new file mode 100644 index 0000000000..912c3deb70 --- /dev/null +++ b/changelog.d/12766.bugfix @@ -0,0 +1 @@ +Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 9a1cc11bb3..0b63cd2186 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple import attr @@ -256,13 +247,19 @@ class RelationsHandler: return filtered_results - async def get_threads_for_events( - self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str] + async def _get_threads_for_events( + self, + events_by_id: Dict[str, EventBase], + relations_by_id: Dict[str, str], + user_id: str, + ignored_users: FrozenSet[str], ) -> Dict[str, _ThreadAggregation]: """Get the bundled aggregations for threads for the requested events. Args: - event_ids: Events to get aggregations for threads. + events_by_id: A map of event_id to events to get aggregations for threads. + relations_by_id: A map of event_id to the relation type, if one exists + for that event. user_id: The user requesting the bundled aggregations. ignored_users: The users ignored by the requesting user. @@ -273,16 +270,34 @@ class RelationsHandler: """ user = UserID.from_string(user_id) + # It is not valid to start a thread on an event which itself relates to another event. + event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id] + # Fetch thread summaries. summaries = await self._main_store.get_thread_summaries(event_ids) - # Only fetch participated for a limited selection based on what had - # summaries. + # Limit fetching whether the requester has participated in a thread to + # events which are thread roots. thread_event_ids = [ event_id for event_id, summary in summaries.items() if summary ] - participated = await self._main_store.get_threads_participated( - thread_event_ids, user_id + + # Pre-seed thread participation with whether the requester sent the event. + participated = { + event_id: events_by_id[event_id].sender == user_id + for event_id in thread_event_ids + } + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [ + event_id + for event_id in thread_event_ids + if not participated[event_id] + ], + user_id, + ) ) # Then subtract off the results for any ignored users. @@ -343,7 +358,8 @@ class RelationsHandler: count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. - current_user_participated=participated[event_id], + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], ) return results @@ -401,9 +417,9 @@ class RelationsHandler: # events to be fetched. Thus, we check those first! # Fetch thread summaries (but only for the directly requested events). - threads = await self.get_threads_for_events( - # It is not valid to start a thread on an event which itself relates to another event. - [eid for eid in events_by_id.keys() if eid not in relations_by_id], + threads = await self._get_threads_for_events( + events_by_id, + relations_by_id, user_id, ignored_users, ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index bc9cc51b92..62e4db23ef 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): relation_type: str, assertion_callable: Callable[[JsonDict], None], expected_db_txn_for_event: int, + access_token: Optional[str] = None, ) -> None: """ Makes requests to various endpoints which should include bundled aggregations @@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): for relation-specific assertions. expected_db_txn_for_event: The number of database transactions which are expected for a call to /event/. + access_token: The access token to user, defaults to self.user_token. """ + access_token = access_token or self.user_token def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" @@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) @@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) @@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) @@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Request sync. filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", f"/sync?filter={filter}", access_token=access_token ) self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] @@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "/search", # Search term matches the parent message. content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) chunk = [ @@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): """ Test that threads get correctly bundled. """ - self._send_relation(RelationTypes.THREAD, "m.room.test") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + # The root message is from "user", send replies as "user2". + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) thread_2 = channel.json_body["event_id"] - def assert_thread(bundled_aggregations: JsonDict) -> None: - self.assertEqual(2, bundled_aggregations.get("count")) - self.assertTrue(bundled_aggregations.get("current_user_participated")) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } + # This needs two assertion functions which are identical except for whether + # the current_user_participated flag is True, create a factory for the + # two versions. + def _gen_assert(participated: bool) -> Callable[[JsonDict], None]: + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertEqual( + participated, bundled_aggregations.get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user2_id, + "type": "m.room.test", }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - bundled_aggregations.get("latest_event"), - ) + bundled_aggregations.get("latest_event"), + ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + return assert_thread + + # The "user" sent the root event and is making queries for the bundled + # aggregations: they have participated. + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + # The "user2" sent replies in the thread and is making queries for the + # bundled aggregations: they have participated. + # + # Note that this re-uses some cached values, so the total number of + # queries is much smaller. + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + ) + + # A user with no interactions with the thread: they have not participated. + user3_id, user3_token = self._create_user("charlie") + self.helper.join(self.room, user=user3_id, tok=user3_token) + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: """ @@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ |