diff --git a/changelog.d/11660.misc b/changelog.d/11660.misc
new file mode 100644
index 0000000000..47e085e4d9
--- /dev/null
+++ b/changelog.d/11660.misc
@@ -0,0 +1 @@
+Improve performance when fetching bundled aggregations for multiple events.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b804185c40..2e44c77715 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1801,9 +1801,7 @@ class PersistEventsStore:
)
if rel_type == RelationTypes.REPLACE:
- txn.call_after(
- self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
- )
+ txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
if rel_type == RelationTypes.THREAD:
txn.call_after(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 37468a5183..6180b17296 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,12 +13,22 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
from frozendict import frozendict
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -28,13 +38,14 @@ from synapse.storage.database import (
make_in_list_sql_clause,
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -340,20 +351,24 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
- async def get_applicable_edit(
- self, event_id: str, room_id: str
- ) -> Optional[EventBase]:
+ def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
+ async def _get_applicable_edits(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
- event.
+ events.
Correctly handles checking whether edits were allowed to happen.
Args:
- event_id: The original event ID
- room_id: The original event's room ID
+ event_ids: The original event IDs
Returns:
- The most recent edit, if any.
+ A map of the most recent edit for each event. If there are no edits,
+ the event will map to None.
"""
# We only allow edits for `m.room.message` events that have the same sender
@@ -362,37 +377,67 @@ class RelationsWorkerStore(SQLBaseStore):
# Fetches latest edit that has the same type and sender as the
# original, and is an `m.room.message`.
- sql = """
- SELECT edit.event_id FROM events AS edit
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS original ON
- original.event_id = relates_to_id
- AND edit.type = original.type
- AND edit.sender = original.sender
- WHERE
- relates_to_id = ?
- AND relation_type = ?
- AND edit.room_id = ?
- AND edit.type = 'm.room.message'
- ORDER by edit.origin_server_ts DESC, edit.event_id DESC
- LIMIT 1
- """
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it encounters,
+ # so ordering by origin server ts + event ID desc will ensure we get
+ # the latest edit.
+ sql = """
+ SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ AND edit.room_id = original.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
+ """
+ else:
+ # SQLite uses a simplified query which returns all edits for an
+ # original event. The results are then de-duplicated when turned into
+ # a dict. Due to the chosen ordering, the latest edit stomps on
+ # earlier edits.
+ sql = """
+ SELECT original.event_id, edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ AND edit.room_id = original.room_id
+ WHERE
+ %s
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts, edit.event_id
+ """
- def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
- txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
- row = txn.fetchone()
- if row:
- return row[0]
- return None
+ def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.REPLACE)
- edit_id = await self.db_pool.runInteraction(
- "get_applicable_edit", _get_applicable_edit_txn
+ txn.execute(sql % (clause,), args)
+ return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))
+
+ edit_ids = await self.db_pool.runInteraction(
+ "get_applicable_edits", _get_applicable_edits_txn
)
- if not edit_id:
- return None
+ edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined]
- return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined]
+ # Map the original event IDs to the edit events.
+ #
+ # There might not be an edit event due to there being no edits or
+ # due to the event not being known; either case is treated the same.
+ return {
+ original_event_id: edits.get(edit_ids.get(original_event_id))
+ for original_event_id in event_ids
+ }
@cached()
async def get_thread_summary(
@@ -612,9 +657,6 @@ class RelationsWorkerStore(SQLBaseStore):
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
- # State events and redacted events do not get bundled aggregations.
- if event.is_state() or event.internal_metadata.is_redacted():
- return None
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
@@ -642,13 +684,6 @@ class RelationsWorkerStore(SQLBaseStore):
if references.chunk:
aggregations.references = references.to_dict()
- edit = None
- if event.type == EventTypes.Message:
- edit = await self.get_applicable_edit(event_id, room_id)
-
- if edit:
- aggregations.replace = edit
-
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
thread_count, latest_thread_event = await self.get_thread_summary(
@@ -668,9 +703,7 @@ class RelationsWorkerStore(SQLBaseStore):
return aggregations
async def get_bundled_aggregations(
- self,
- events: Iterable[EventBase],
- user_id: str,
+ self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
@@ -683,13 +716,28 @@ class RelationsWorkerStore(SQLBaseStore):
events may have bundled aggregations in the results.
"""
- # TODO Parallelize.
- results = {}
+ # State events and redacted events do not get bundled aggregations.
+ events = [
+ event
+ for event in events
+ if not event.is_state() and not event.internal_metadata.is_redacted()
+ ]
+
+ # event ID -> bundled aggregation in non-serialized form.
+ results: Dict[str, BundledAggregations] = {}
+
+ # Fetch other relations per event.
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
+ # Fetch any edits.
+ event_ids = [event.event_id for event in events]
+ edits = await self._get_applicable_edits(event_ids)
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
return results
|