diff --git a/changelog.d/11752.misc b/changelog.d/11752.misc
new file mode 100644
index 0000000000..47e085e4d9
--- /dev/null
+++ b/changelog.d/11752.misc
@@ -0,0 +1 @@
+Improve performance when fetching bundled aggregations for multiple events.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2e44c77715..5246fccad5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1812,7 +1812,7 @@ class PersistEventsStore:
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
- (parent_id, event.room_id, event.sender),
+ (parent_id, event.sender),
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index ad79cc5610..e2c27e594b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,6 +20,7 @@ from typing import (
Iterable,
List,
Optional,
+ Set,
Tuple,
Union,
cast,
@@ -454,106 +455,175 @@ class RelationsWorkerStore(SQLBaseStore):
}
@cached()
- async def get_thread_summary(
- self, event_id: str, room_id: str
- ) -> Tuple[int, Optional[EventBase]]:
+ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
+ async def _get_thread_summaries(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event.
Args:
- event_id: Summarize the thread related to this event ID.
- room_id: The room the event belongs to.
+ event_ids: Summarize the threads related to these event IDs.
Returns:
- The number of items in the thread and the most recent response, if any.
+ A map of event ID to thread summary. A missing event implies there
+ are no threaded replies.
+
+ Each summary includes the number of items in the thread and the most
+ recent response.
"""
- def _get_thread_summary_txn(
+ def _get_thread_summaries_txn(
txn: LoggingTransaction,
- ) -> Tuple[int, Optional[str]]:
- # Fetch the latest event ID in the thread.
+ ) -> Tuple[Dict[str, int], Dict[str, str]]:
+ # Fetch the count of threaded events and the latest event ID.
# TODO Should this only allow m.room.message events.
- sql = """
- SELECT event_id
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE
- relates_to_id = ?
- AND room_id = ?
- AND relation_type = ?
- ORDER BY topological_ordering DESC, stream_ordering DESC
- LIMIT 1
- """
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it encounters,
+ # so ordering by topological ordering + stream ordering desc will
+ # ensure we get the latest event in the thread.
+ sql = """
+ SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
+ """
+ else:
+ # SQLite uses a simplified query which returns all entries for a
+ # thread. The first result for each thread is chosen and subsequent
+ # results for a thread are ignored.
+ sql = """
+ SELECT parent.event_id, child.event_id FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
+ """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.THREAD)
- txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
- row = txn.fetchone()
- if row is None:
- return 0, None
+ txn.execute(sql % (clause,), args)
+ latest_event_ids = {}
+ for parent_event_id, child_event_id in txn:
+ # Only consider the latest threaded reply (by topological ordering).
+ if parent_event_id not in latest_event_ids:
+ latest_event_ids[parent_event_id] = child_event_id
- latest_event_id = row[0]
+ # If no threads were found, bail.
+ if not latest_event_ids:
+ return {}, latest_event_ids
# Fetch the number of threaded replies.
sql = """
- SELECT COUNT(event_id)
- FROM event_relations
- INNER JOIN events USING (event_id)
+ SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
WHERE
- relates_to_id = ?
- AND room_id = ?
+ %s
AND relation_type = ?
+ GROUP BY parent.event_id
"""
- txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
- count = cast(Tuple[int], txn.fetchone())[0]
- return count, latest_event_id
+ # Regenerate the arguments since only threads found above could
+ # possibly have any replies.
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", latest_event_ids.keys()
+ )
+ args.append(RelationTypes.THREAD)
+
+ txn.execute(sql % (clause,), args)
+ counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
- count, latest_event_id = await self.db_pool.runInteraction(
- "get_thread_summary", _get_thread_summary_txn
+ return counts, latest_event_ids
+
+ counts, latest_event_ids = await self.db_pool.runInteraction(
+ "get_thread_summaries", _get_thread_summaries_txn
)
- latest_event = None
- if latest_event_id:
- latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined]
+ latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
+
+ # Map the event IDs to the thread summary.
+ #
+ # There might not be a summary due to there not being a thread or
+ # due to the latest event not being known, either case is treated the same.
+ summaries = {}
+ for parent_event_id, latest_event_id in latest_event_ids.items():
+ latest_event = latest_events.get(latest_event_id)
+
+ summary = None
+ if latest_event:
+ summary = (counts[parent_event_id], latest_event)
+ summaries[parent_event_id] = summary
- return count, latest_event
+ return summaries
@cached()
- async def get_thread_participated(
- self, event_id: str, room_id: str, user_id: str
- ) -> bool:
- """Get whether the requesting user participated in a thread.
+ def get_thread_participated(self, event_id: str, user_id: str) -> bool:
+ raise NotImplementedError()
- This is separate from get_thread_summary since that can be cached across
- all users while this value is specific to the requeser.
+ @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
+ async def _get_threads_participated(
+ self, event_ids: Collection[str], user_id: str
+ ) -> Dict[str, bool]:
+ """Get whether the requesting user participated in the given threads.
+
+ This is separate from get_thread_summaries since that can be cached across
+ all users while this value is specific to the requester.
Args:
- event_id: The thread related to this event ID.
- room_id: The room the event belongs to.
+ event_ids: The threads related to these event IDs.
user_id: The user requesting the summary.
Returns:
- True if the requesting user participated in the thread, otherwise false.
+ A map of event ID to a boolean which represents if the requesting
+ user participated in that event's thread, otherwise false.
"""
- def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+ def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
# Fetch whether the requester has participated or not.
sql = """
- SELECT 1
- FROM event_relations
- INNER JOIN events USING (event_id)
+ SELECT DISTINCT relates_to_id
+ FROM events AS child
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = child.room_id
WHERE
- relates_to_id = ?
- AND room_id = ?
+ %s
AND relation_type = ?
- AND sender = ?
+ AND child.sender = ?
"""
- txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
- return bool(txn.fetchone())
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+ args.extend((RelationTypes.THREAD, user_id))
- return await self.db_pool.runInteraction(
+ txn.execute(sql % (clause,), args)
+ return {row[0] for row in txn.fetchall()}
+
+ participated_threads = await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)
+ return {event_id: event_id in participated_threads for event_id in event_ids}
+
async def events_have_relations(
self,
parent_ids: List[str],
@@ -700,21 +770,6 @@ class RelationsWorkerStore(SQLBaseStore):
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
- # If this event is the start of a thread, include a summary of the replies.
- if self._msc3440_enabled:
- thread_count, latest_thread_event = await self.get_thread_summary(
- event_id, room_id
- )
- participated = await self.get_thread_participated(
- event_id, room_id, user_id
- )
- if latest_thread_event:
- aggregations.thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- count=thread_count,
- current_user_participated=participated,
- )
-
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@@ -763,6 +818,27 @@ class RelationsWorkerStore(SQLBaseStore):
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
+ # Fetch thread summaries.
+ if self._msc3440_enabled:
+ summaries = await self._get_thread_summaries(seen_event_ids)
+ # Only fetch participated for a limited selection based on what had
+ # summaries.
+ participated = await self._get_threads_participated(
+ summaries.keys(), user_id
+ )
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event = summary
+ results.setdefault(
+ event_id, BundledAggregations()
+ ).thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
return results
|