summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py150
1 files changed, 99 insertions, 51 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 37468a5183..6180b17296 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,12 +13,22 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
 from frozendict import frozendict
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -28,13 +38,14 @@ from synapse.storage.database import (
     make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.relations import (
     AggregationPaginationToken,
     PaginationChunk,
     RelationPaginationToken,
 )
 from synapse.types import JsonDict
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -340,20 +351,24 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(
-        self, event_id: str, room_id: str
-    ) -> Optional[EventBase]:
+    def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
+    async def _get_applicable_edits(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
-        event.
+        events.
 
         Correctly handles checking whether edits were allowed to happen.
 
         Args:
-            event_id: The original event ID
-            room_id: The original event's room ID
+            event_ids: The original event IDs
 
         Returns:
-            The most recent edit, if any.
+            A map of the most recent edit for each event. If there are no edits,
+            the event will map to None.
         """
 
         # We only allow edits for `m.room.message` events that have the same sender
@@ -362,37 +377,67 @@ class RelationsWorkerStore(SQLBaseStore):
 
         # Fetches latest edit that has the same type and sender as the
         # original, and is an `m.room.message`.
-        sql = """
-            SELECT edit.event_id FROM events AS edit
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS original ON
-                original.event_id = relates_to_id
-                AND edit.type = original.type
-                AND edit.sender = original.sender
-            WHERE
-                relates_to_id = ?
-                AND relation_type = ?
-                AND edit.room_id = ?
-                AND edit.type = 'm.room.message'
-            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
-            LIMIT 1
-        """
+        if isinstance(self.database_engine, PostgresEngine):
+            # The `DISTINCT ON` clause will pick the *first* row it encounters,
+            # so ordering by origin server ts + event ID desc will ensure we get
+            # the latest edit.
+            sql = """
+                SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS original ON
+                    original.event_id = relates_to_id
+                    AND edit.type = original.type
+                    AND edit.sender = original.sender
+                    AND edit.room_id = original.room_id
+                WHERE
+                    %s
+                    AND relation_type = ?
+                    AND edit.type = 'm.room.message'
+                ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
+            """
+        else:
+            # SQLite uses a simplified query which returns all edits for an
+            # original event. The results are then de-duplicated when turned into
+            # a dict. Due to the chosen ordering, the latest edit stomps on
+            # earlier edits.
+            sql = """
+                SELECT original.event_id, edit.event_id FROM events AS edit
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS original ON
+                    original.event_id = relates_to_id
+                    AND edit.type = original.type
+                    AND edit.sender = original.sender
+                    AND edit.room_id = original.room_id
+                WHERE
+                    %s
+                    AND relation_type = ?
+                    AND edit.type = 'm.room.message'
+                ORDER by edit.origin_server_ts, edit.event_id
+            """
 
-        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
-            row = txn.fetchone()
-            if row:
-                return row[0]
-            return None
+        def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+            args.append(RelationTypes.REPLACE)
 
-        edit_id = await self.db_pool.runInteraction(
-            "get_applicable_edit", _get_applicable_edit_txn
+            txn.execute(sql % (clause,), args)
+            return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))
+
+        edit_ids = await self.db_pool.runInteraction(
+            "get_applicable_edits", _get_applicable_edits_txn
         )
 
-        if not edit_id:
-            return None
+        edits = await self.get_events(edit_ids.values())  # type: ignore[attr-defined]
 
-        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]
+        # Map to the original event IDs to the edit events.
+        #
+        # There might not be an edit event due to there being no edits or
+        # due to the event not being known, either case is treated the same.
+        return {
+            original_event_id: edits.get(edit_ids.get(original_event_id))
+            for original_event_id in event_ids
+        }
 
     @cached()
     async def get_thread_summary(
@@ -612,9 +657,6 @@ class RelationsWorkerStore(SQLBaseStore):
             The bundled aggregations for an event, if bundled aggregations are
             enabled and the event can have bundled aggregations.
         """
-        # State events and redacted events do not get bundled aggregations.
-        if event.is_state() or event.internal_metadata.is_redacted():
-            return None
 
         # Do not bundle aggregations for an event which represents an edit or an
         # annotation. It does not make sense for them to have related events.
@@ -642,13 +684,6 @@ class RelationsWorkerStore(SQLBaseStore):
         if references.chunk:
             aggregations.references = references.to_dict()
 
-        edit = None
-        if event.type == EventTypes.Message:
-            edit = await self.get_applicable_edit(event_id, room_id)
-
-        if edit:
-            aggregations.replace = edit
-
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
             thread_count, latest_thread_event = await self.get_thread_summary(
@@ -668,9 +703,7 @@ class RelationsWorkerStore(SQLBaseStore):
         return aggregations
 
     async def get_bundled_aggregations(
-        self,
-        events: Iterable[EventBase],
-        user_id: str,
+        self, events: Iterable[EventBase], user_id: str
     ) -> Dict[str, BundledAggregations]:
         """Generate bundled aggregations for events.
 
@@ -683,13 +716,28 @@ class RelationsWorkerStore(SQLBaseStore):
             events may have bundled aggregations in the results.
         """
 
-        # TODO Parallelize.
-        results = {}
+        # State events and redacted events do not get bundled aggregations.
+        events = [
+            event
+            for event in events
+            if not event.is_state() and not event.internal_metadata.is_redacted()
+        ]
+
+        # event ID -> bundled aggregation in non-serialized form.
+        results: Dict[str, BundledAggregations] = {}
+
+        # Fetch other relations per event.
         for event in events:
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
             if event_result:
                 results[event.event_id] = event_result
 
+        # Fetch any edits.
+        event_ids = [event.event_id for event in events]
+        edits = await self._get_applicable_edits(event_ids)
+        for event_id, edit in edits.items():
+            results.setdefault(event_id, BundledAggregations()).replace = edit
+
         return results