diff options
-rw-r--r-- | changelog.d/12235.bugfix | 1 | ||||
-rw-r--r-- | changelog.d/12338.bugfix | 1 | ||||
-rw-r--r-- | changelog.d/12338.misc | 1 | ||||
-rw-r--r-- | synapse/handlers/relations.py | 244 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 152 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 73 |
6 files changed, 427 insertions, 45 deletions
diff --git a/changelog.d/12235.bugfix b/changelog.d/12235.bugfix new file mode 100644 index 0000000000..b5d2bede67 --- /dev/null +++ b/changelog.d/12235.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where events from ignored users were still considered for bundled aggregations. diff --git a/changelog.d/12338.bugfix b/changelog.d/12338.bugfix new file mode 100644 index 0000000000..b5d2bede67 --- /dev/null +++ b/changelog.d/12338.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where events from ignored users were still considered for bundled aggregations. diff --git a/changelog.d/12338.misc b/changelog.d/12338.misc deleted file mode 100644 index 376089f327..0000000000 --- a/changelog.d/12338.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor relations code to remove an unnecessary class. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index a36936b520..0be2319577 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, Optional +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, +) import attr from frozendict import frozendict @@ -20,7 +29,8 @@ from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase -from synapse.types import JsonDict, Requester, StreamToken +from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -115,6 +125,9 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). related_events, next_token = await self._main_store.get_relations_for_event( event_id=event_id, event=event, @@ -128,7 +141,9 @@ class RelationsHandler: to_token=to_token, ) - events = await self._main_store.get_events_as_list(related_events) + events = await self._main_store.get_events_as_list( + [e.event_id for e in related_events] + ) events = await filter_events_for_client( self._storage, user_id, events, is_peeking=(member_event_id is None) @@ -162,8 +177,87 @@ class RelationsHandler: return return_value + async def get_relations_for_event( + self, + event_id: str, + event: EventBase, + room_id: str, + relation_type: str, + ignored_users: FrozenSet[str] = frozenset(), + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: + """Get a list of events which relate to an event, ordered by topological ordering. + + Args: + event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. + room_id: The room the event belongs to. + relation_type: The type of relation. + ignored_users: The users ignored by the requesting user. + + Returns: + List of event IDs that match relations requested. The rows are of + the form `{"event_id": "..."}`. + """ + + # Call the underlying storage method, which is cached. + related_events, next_token = await self._main_store.get_relations_for_event( + event_id, event, room_id, relation_type, direction="f" + ) + + # Filter out ignored users and convert to the expected format. + related_events = [ + event for event in related_events if event.sender not in ignored_users + ] + + return related_events, next_token + + async def get_annotations_for_event( + self, + event_id: str, + room_id: str, + limit: int = 5, + ignored_users: FrozenSet[str] = frozenset(), + ) -> List[JsonDict]: + """Get a list of annotations on the event, grouped by event type and + aggregation key, sorted by count. + + This is used e.g. to get the what and how many reactions have happend + on an event. + + Args: + event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. + limit: Only fetch the `limit` groups. + ignored_users: The users ignored by the requesting user. + + Returns: + List of groups of annotations that match. Each row is a dict with + `type`, `key` and `count` fields. + """ + # Get the base results for all users. + full_results = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id, limit + ) + + # Then subtract off the results for any ignored users. + ignored_results = await self._main_store.get_aggregation_groups_for_users( + event_id, room_id, limit, ignored_users + ) + + filtered_results = [] + for result in full_results: + key = (result["type"], result["key"]) + if key in ignored_results: + result = result.copy() + result["count"] -= ignored_results[key] + if result["count"] <= 0: + continue + filtered_results.append(result) + + return filtered_results + async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str + self, event: EventBase, ignored_users: FrozenSet[str] ) -> Optional[BundledAggregations]: """Generate bundled aggregations for an event. @@ -171,7 +265,7 @@ class RelationsHandler: Args: event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. + ignored_users: The users ignored by the requesting user. Returns: The bundled aggregations for an event, if bundled aggregations are @@ -194,18 +288,22 @@ class RelationsHandler: # while others need more processing during serialization. aggregations = BundledAggregations() - annotations = await self._main_store.get_aggregation_groups_for_event( - event_id, room_id + annotations = await self.get_annotations_for_event( + event_id, room_id, ignored_users=ignored_users ) if annotations: aggregations.annotations = {"chunk": annotations} - references, next_token = await self._main_store.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + references, next_token = await self.get_relations_for_event( + event_id, + event, + room_id, + RelationTypes.REFERENCE, + ignored_users=ignored_users, ) if references: aggregations.references = { - "chunk": [{"event_id": event_id} for event_id in references] + "chunk": [{"event_id": event.event_id} for event in references] } if next_token: @@ -216,6 +314,99 @@ class RelationsHandler: # Store the bundled aggregations in the event metadata for later use. return aggregations + async def get_threads_for_events( + self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str] + ) -> Dict[str, _ThreadAggregation]: + """Get the bundled aggregations for threads for the requested events. + + Args: + event_ids: Events to get aggregations for threads. + user_id: The user requesting the bundled aggregations. + ignored_users: The users ignored by the requesting user. + + Returns: + A dictionary mapping event ID to the thread information. + + May not contain a value for all requested event IDs. + """ + user = UserID.from_string(user_id) + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(event_ids) + + # Only fetch participated for a limited selection based on what had + # summaries. + thread_event_ids = [ + event_id for event_id, summary in summaries.items() if summary + ] + participated = await self._main_store.get_threads_participated( + thread_event_ids, user_id + ) + + # Then subtract off the results for any ignored users. + ignored_results = await self._main_store.get_threaded_messages_per_user( + thread_event_ids, ignored_users + ) + + # A map of event ID to the thread aggregation. + results = {} + + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + + # Subtract off the count of any ignored users. + for ignored_user in ignored_users: + thread_count -= ignored_results.get((event_id, ignored_user), 0) + + # This is gnarly, but if the latest event is from an ignored user, + # attempt to find one that isn't from an ignored user. + if latest_thread_event.sender in ignored_users: + room_id = latest_thread_event.room_id + + # If the root event is not found, something went wrong, do + # not include a summary of the thread. + event = await self._event_handler.get_event(user, room_id, event_id) + if event is None: + continue + + potential_events, _ = await self.get_relations_for_event( + event_id, + event, + room_id, + RelationTypes.THREAD, + ignored_users, + ) + + # If all found events are from ignored users, do not include + # a summary of the thread. + if not potential_events: + continue + + # The *last* event returned is the one that is cared about. + event = await self._event_handler.get_event( + user, room_id, potential_events[-1].event_id + ) + # It is unexpected that the event will not exist. + if event is None: + logger.warning( + "Unable to fetch latest event in a thread with event ID: %s", + potential_events[-1].event_id, + ) + continue + latest_thread_event = event + + results[event_id] = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + + return results + async def get_bundled_aggregations( self, events: Iterable[EventBase], user_id: str ) -> Dict[str, BundledAggregations]: @@ -239,13 +430,21 @@ class RelationsHandler: # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} + # Fetch any ignored users of the requesting user. + ignored_users = await self._main_store.ignored_users(user_id) + # Fetch other relations per event. for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) + event_result = await self._get_bundled_aggregation_for_event( + event, ignored_users + ) if event_result: results[event.event_id] = event_result # Fetch any edits (but not for redacted events). + # + # Note that there is no use in limiting edits by ignored users since the + # parent event should be ignored in the first place if the user is ignored. edits = await self._main_store.get_applicable_edits( [ event_id @@ -256,25 +455,10 @@ class RelationsHandler: for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit - # Fetch thread summaries. - summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._main_store.get_threads_participated( - [event_id for event_id, summary in summaries.items() if summary], user_id + threads = await self.get_threads_for_events( + events_by_id.keys(), user_id, ignored_users ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + for event_id, thread in threads.items(): + results.setdefault(event_id, BundledAggregations()).thread = thread return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 64a7808140..db929ef523 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, Collection, Dict, + FrozenSet, Iterable, List, Optional, @@ -26,6 +27,8 @@ from typing import ( cast, ) +import attr + from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -46,6 +49,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _RelatedEvent: + """ + Contains enough information about a related event in order to properly filter + events from ignored users. + """ + + # The event ID of the related event. + event_id: str + # The sender of the related event. + sender: str + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -70,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> Tuple[List[str], Optional[StreamToken]]: + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -88,7 +104,7 @@ class RelationsWorkerStore(SQLBaseStore): Returns: A tuple of: - A list of related event IDs + A list of related event IDs & their senders. The next stream token, if one exists. """ @@ -131,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, relation_type, topological_ordering, stream_ordering + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -145,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> Tuple[List[str], Optional[StreamToken]]: + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -155,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore): # Do not include edits for redacted events as they leak event # content. if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append(row[0]) - last_topo_id = row[2] - last_stream_id = row[3] + events.append(_RelatedEvent(row[0], row[2])) + last_topo_id = row[3] + last_stream_id = row[4] # If there are more events, generate the next pagination key. next_token = None @@ -267,7 +283,7 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_args = [ + args = [ event_id, room_id, RelationTypes.ANNOTATION, @@ -287,7 +303,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_aggregation_groups_for_event_txn( txn: LoggingTransaction, ) -> List[JsonDict]: - txn.execute(sql, where_args) + txn.execute(sql, args) return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] @@ -295,6 +311,63 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) + async def get_aggregation_groups_for_users( + self, + event_id: str, + room_id: str, + limit: int, + users: FrozenSet[str] = frozenset(), + ) -> Dict[Tuple[str, str], int]: + """Fetch the partial aggregations for an event for specific users. + + This is used, in conjunction with get_aggregation_groups_for_event, to + remove information from the results for ignored users. + + Args: + event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. + limit: Only fetch the `limit` groups. + users: The users to fetch information for. + + Returns: + A map of (event type, aggregation key) to a count of users. + """ + + if not users: + return {} + + args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] + + users_sql, users_args = make_in_list_sql_clause( + self.database_engine, "sender", users + ) + args.extend(users_args) + + sql = f""" + SELECT type, aggregation_key, COUNT(DISTINCT sender) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} + GROUP BY relation_type, type, aggregation_key + ORDER BY COUNT(*) DESC + LIMIT ? + """ + + def _get_aggregation_groups_for_users_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, str], int]: + txn.execute(sql, args + [limit]) + + return {(row[0], row[1]): row[2] for row in txn} + + return await self.db_pool.runInteraction( + "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() @@ -521,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore): return summaries + async def get_threaded_messages_per_user( + self, + event_ids: Collection[str], + users: FrozenSet[str] = frozenset(), + ) -> Dict[Tuple[str, str], int]: + """Get the number of threaded replies for a set of users. + + This is used, in conjunction with get_thread_summaries, to calculate an + accurate count of the replies to a thread by subtracting ignored users. + + Args: + event_ids: The events to check for threaded replies. + users: The user to calculate the count of their replies. + + Returns: + A map of the (event_id, sender) to the count of their replies. + """ + if not users: + return {} + + # Fetch the number of threaded replies. + sql = """ + SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND %s + AND %s + GROUP BY parent.event_id, child.sender + """ + + def _get_threaded_messages_per_user_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, str], int]: + users_sql, users_args = make_in_list_sql_clause( + self.database_engine, "child.sender", users + ) + events_clause, events_args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD] + else: + relations_clause = "relation_type = ?" + relations_args = [RelationTypes.THREAD] + + txn.execute( + sql % (users_sql, events_clause, relations_clause), + users_args + events_args + relations_args, + ) + return {(row[0], row[1]): row[2] for row in txn} + + return await self.db_pool.runInteraction( + "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn + ) + @cached() def get_thread_participated(self, event_id: str, user_id: str) -> bool: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 2f2ec3a685..6fabada8b3 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1137,16 +1137,27 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): """Relations sent from an ignored user should be ignored.""" def _test_ignored_user( - self, allowed_event_ids: List[str], ignored_event_ids: List[str] - ) -> None: + self, + relation_type: str, + allowed_event_ids: List[str], + ignored_event_ids: List[str], + ) -> Tuple[JsonDict, JsonDict]: """ Fetch the relations and ensure they're all there, then ignore user2, and repeat. + + Returns: + A tuple of two JSON dictionaries, each are bundled aggregations, the + first is from before the user is ignored, and the second is after. """ # Get the relations. event_ids = self._get_related_events() self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids) + # And the bundled aggregations. + before_aggregations = self._get_bundled_aggregations() + self.assertIn(relation_type, before_aggregations) + # Ignore user2 and re-do the requests. self.get_success( self.store.add_account_data_for_user( @@ -1160,6 +1171,12 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): event_ids = self._get_related_events() self.assertCountEqual(event_ids, allowed_event_ids) + # And the bundled aggregations. + after_aggregations = self._get_bundled_aggregations() + self.assertIn(relation_type, after_aggregations) + + return before_aggregations[relation_type], after_aggregations[relation_type] + def test_annotation(self) -> None: """Annotations should ignore""" # Send 2 from us, 2 from the to be ignored user. @@ -1184,7 +1201,26 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): ) ignored_event_ids.append(channel.json_body["event_id"]) - self._test_ignored_user(allowed_event_ids, ignored_event_ids) + before_aggregations, after_aggregations = self._test_ignored_user( + RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids + ) + + self.assertCountEqual( + before_aggregations["chunk"], + [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + {"type": "m.reaction", "key": "c", "count": 1}, + ], + ) + + self.assertCountEqual( + after_aggregations["chunk"], + [ + {"type": "m.reaction", "key": "a", "count": 1}, + {"type": "m.reaction", "key": "b", "count": 1}, + ], + ) def test_reference(self) -> None: """Annotations should ignore""" @@ -1196,7 +1232,18 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): ) ignored_event_ids = [channel.json_body["event_id"]] - self._test_ignored_user(allowed_event_ids, ignored_event_ids) + before_aggregations, after_aggregations = self._test_ignored_user( + RelationTypes.REFERENCE, allowed_event_ids, ignored_event_ids + ) + + self.assertCountEqual( + [e["event_id"] for e in before_aggregations["chunk"]], + allowed_event_ids + ignored_event_ids, + ) + + self.assertCountEqual( + [e["event_id"] for e in after_aggregations["chunk"]], allowed_event_ids + ) def test_thread(self) -> None: """Annotations should ignore""" @@ -1208,7 +1255,23 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase): ) ignored_event_ids = [channel.json_body["event_id"]] - self._test_ignored_user(allowed_event_ids, ignored_event_ids) + before_aggregations, after_aggregations = self._test_ignored_user( + RelationTypes.THREAD, allowed_event_ids, ignored_event_ids + ) + + self.assertEqual(before_aggregations["count"], 2) + self.assertTrue(before_aggregations["current_user_participated"]) + # The latest thread event has some fields that don't matter. + self.assertEqual( + before_aggregations["latest_event"]["event_id"], ignored_event_ids[0] + ) + + self.assertEqual(after_aggregations["count"], 1) + self.assertTrue(after_aggregations["current_user_participated"]) + # The latest thread event has some fields that don't matter. + self.assertEqual( + after_aggregations["latest_event"]["event_id"], allowed_event_ids[0] + ) class RelationRedactionTestCase(BaseRelationsTestCase): |