diff options
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/relations.py | 152 |
1 files changed, 143 insertions, 9 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 64a7808140..db929ef523 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, Collection, Dict, + FrozenSet, Iterable, List, Optional, @@ -26,6 +27,8 @@ from typing import ( cast, ) +import attr + from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -46,6 +49,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _RelatedEvent: + """ + Contains enough information about a related event in order to properly filter + events from ignored users. + """ + + # The event ID of the related event. + event_id: str + # The sender of the related event. + sender: str + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -70,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> Tuple[List[str], Optional[StreamToken]]: + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -88,7 +104,7 @@ class RelationsWorkerStore(SQLBaseStore): Returns: A tuple of: - A list of related event IDs + A list of related event IDs & their senders. The next stream token, if one exists. """ @@ -131,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, relation_type, topological_ordering, stream_ordering + SELECT event_id, relation_type, sender, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -145,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> Tuple[List[str], Optional[StreamToken]]: + ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -155,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore): # Do not include edits for redacted events as they leak event # content. if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append(row[0]) - last_topo_id = row[2] - last_stream_id = row[3] + events.append(_RelatedEvent(row[0], row[2])) + last_topo_id = row[3] + last_stream_id = row[4] # If there are more events, generate the next pagination key. next_token = None @@ -267,7 +283,7 @@ class RelationsWorkerStore(SQLBaseStore): `type`, `key` and `count` fields. """ - where_args = [ + args = [ event_id, room_id, RelationTypes.ANNOTATION, @@ -287,7 +303,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_aggregation_groups_for_event_txn( txn: LoggingTransaction, ) -> List[JsonDict]: - txn.execute(sql, where_args) + txn.execute(sql, args) return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] @@ -295,6 +311,63 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) + async def get_aggregation_groups_for_users( + self, + event_id: str, + room_id: str, + limit: int, + users: FrozenSet[str] = frozenset(), + ) -> Dict[Tuple[str, str], int]: + """Fetch the partial aggregations for an event for specific users. + + This is used, in conjunction with get_aggregation_groups_for_event, to + remove information from the results for ignored users. + + Args: + event_id: Fetch events that relate to this event ID. + room_id: The room the event belongs to. + limit: Only fetch the `limit` groups. + users: The users to fetch information for. + + Returns: + A map of (event type, aggregation key) to a count of users. + """ + + if not users: + return {} + + args: List[Union[str, int]] = [ + event_id, + room_id, + RelationTypes.ANNOTATION, + ] + + users_sql, users_args = make_in_list_sql_clause( + self.database_engine, "sender", users + ) + args.extend(users_args) + + sql = f""" + SELECT type, aggregation_key, COUNT(DISTINCT sender) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} + GROUP BY relation_type, type, aggregation_key + ORDER BY COUNT(*) DESC + LIMIT ? + """ + + def _get_aggregation_groups_for_users_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, str], int]: + txn.execute(sql, args + [limit]) + + return {(row[0], row[1]): row[2] for row in txn} + + return await self.db_pool.runInteraction( + "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() @@ -521,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore): return summaries + async def get_threaded_messages_per_user( + self, + event_ids: Collection[str], + users: FrozenSet[str] = frozenset(), + ) -> Dict[Tuple[str, str], int]: + """Get the number of threaded replies for a set of users. + + This is used, in conjunction with get_thread_summaries, to calculate an + accurate count of the replies to a thread by subtracting ignored users. + + Args: + event_ids: The events to check for threaded replies. + users: The user to calculate the count of their replies. + + Returns: + A map of the (event_id, sender) to the count of their replies. + """ + if not users: + return {} + + # Fetch the number of threaded replies. + sql = """ + SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND %s + AND %s + GROUP BY parent.event_id, child.sender + """ + + def _get_threaded_messages_per_user_txn( + txn: LoggingTransaction, + ) -> Dict[Tuple[str, str], int]: + users_sql, users_args = make_in_list_sql_clause( + self.database_engine, "child.sender", users + ) + events_clause, events_args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD] + else: + relations_clause = "relation_type = ?" + relations_args = [RelationTypes.THREAD] + + txn.execute( + sql % (users_sql, events_clause, relations_clause), + users_args + events_args + relations_args, + ) + return {(row[0], row[1]): row[2] for row in txn} + + return await self.db_pool.runInteraction( + "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn + ) + @cached() def get_thread_participated(self, event_id: str, user_id: str) -> bool: raise NotImplementedError() |