diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 64a7808140..db929ef523 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Collection,
Dict,
+ FrozenSet,
Iterable,
List,
Optional,
@@ -26,6 +27,8 @@ from typing import (
cast,
)
+import attr
+
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _RelatedEvent:
+ """
+ Contains enough information about a related event in order to properly filter
+ events from ignored users.
+ """
+
+ # The event ID of the related event.
+ event_id: str
+ # The sender of the related event.
+ sender: str
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -70,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> Tuple[List[str], Optional[StreamToken]]:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -88,7 +104,7 @@ class RelationsWorkerStore(SQLBaseStore):
Returns:
A tuple of:
- A list of related event IDs
+ A list of related event IDs & their senders.
The next stream token, if one exists.
"""
@@ -131,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, relation_type, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -145,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[str], Optional[StreamToken]]:
+ ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -155,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
- events.append(row[0])
- last_topo_id = row[2]
- last_stream_id = row[3]
+ events.append(_RelatedEvent(row[0], row[2]))
+ last_topo_id = row[3]
+ last_stream_id = row[4]
# If there are more events, generate the next pagination key.
next_token = None
@@ -267,7 +283,7 @@ class RelationsWorkerStore(SQLBaseStore):
`type`, `key` and `count` fields.
"""
- where_args = [
+ args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
@@ -287,7 +303,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> List[JsonDict]:
- txn.execute(sql, where_args)
+ txn.execute(sql, args)
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
@@ -295,6 +311,63 @@ class RelationsWorkerStore(SQLBaseStore):
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
+ async def get_aggregation_groups_for_users(
+ self,
+ event_id: str,
+ room_id: str,
+ limit: int,
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Fetch the partial aggregations for an event for specific users.
+
+ This is used, in conjunction with get_aggregation_groups_for_event, to
+ remove information from the results for ignored users.
+
+ Args:
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ limit: Only fetch the `limit` groups.
+ users: The users to fetch information for.
+
+ Returns:
+ A map of (event type, aggregation key) to a count of users.
+ """
+
+ if not users:
+ return {}
+
+ args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
+
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "sender", users
+ )
+ args.extend(users_args)
+
+ sql = f"""
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
+ GROUP BY relation_type, type, aggregation_key
+ ORDER BY COUNT(*) DESC
+ LIMIT ?
+ """
+
+ def _get_aggregation_groups_for_users_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[Tuple[str, str], int]:
+ txn.execute(sql, args + [limit])
+
+ return {(row[0], row[1]): row[2] for row in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
+ )
+
@cached()
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()
@@ -521,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
return summaries
+ async def get_threaded_messages_per_user(
+ self,
+ event_ids: Collection[str],
+ users: FrozenSet[str] = frozenset(),
+ ) -> Dict[Tuple[str, str], int]:
+ """Get the number of threaded replies for a set of users.
+
+ This is used, in conjunction with get_thread_summaries, to calculate an
+ accurate count of the replies to a thread by subtracting ignored users.
+
+ Args:
+ event_ids: The events to check for threaded replies.
+ users: The user to calculate the count of their replies.
+
+ Returns:
+ A map of the (event_id, sender) to the count of their replies.
+ """
+ if not users:
+ return {}
+
+ # Fetch the number of threaded replies.
+ sql = """
+ SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND %s
+ AND %s
+ GROUP BY parent.event_id, child.sender
+ """
+
+ def _get_threaded_messages_per_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[Tuple[str, str], int]:
+ users_sql, users_args = make_in_list_sql_clause(
+ self.database_engine, "child.sender", users
+ )
+ events_clause, events_args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
+ else:
+ relations_clause = "relation_type = ?"
+ relations_args = [RelationTypes.THREAD]
+
+ txn.execute(
+ sql % (users_sql, events_clause, relations_clause),
+ users_args + events_args + relations_args,
+ )
+ return {(row[0], row[1]): row[2] for row in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
+ )
+
@cached()
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()
|