summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14508.feature1
-rw-r--r--synapse/handlers/relations.py128
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/relations.py74
-rw-r--r--tests/rest/client/test_relations.py4
6 files changed, 133 insertions, 79 deletions
diff --git a/changelog.d/14508.feature b/changelog.d/14508.feature
new file mode 100644
index 0000000000..4fca7282f7
--- /dev/null
+++ b/changelog.d/14508.feature
@@ -0,0 +1 @@
+Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index ca94239f61..8414be5879 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -13,16 +13,7 @@
 # limitations under the License.
 import enum
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    FrozenSet,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
 
 import attr
 
@@ -32,7 +23,7 @@ from synapse.events import EventBase, relation_from_event
 from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -181,40 +172,6 @@ class RelationsHandler:
 
         return return_value
 
-    async def get_relations_for_event(
-        self,
-        event_id: str,
-        event: EventBase,
-        room_id: str,
-        relation_type: str,
-        ignored_users: FrozenSet[str] = frozenset(),
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
-        """Get a list of events which relate to an event, ordered by topological ordering.
-
-        Args:
-            event_id: Fetch events that relate to this event ID.
-            event: The matching EventBase to event_id.
-            room_id: The room the event belongs to.
-            relation_type: The type of relation.
-            ignored_users: The users ignored by the requesting user.
-
-        Returns:
-            List of event IDs that match relations requested. The rows are of
-            the form `{"event_id": "..."}`.
-        """
-
-        # Call the underlying storage method, which is cached.
-        related_events, next_token = await self._main_store.get_relations_for_event(
-            event_id, event, room_id, relation_type, direction="f"
-        )
-
-        # Filter out ignored users and convert to the expected format.
-        related_events = [
-            event for event in related_events if event.sender not in ignored_users
-        ]
-
-        return related_events, next_token
-
     async def redact_events_related_to(
         self,
         requester: Requester,
@@ -329,6 +286,46 @@ class RelationsHandler:
 
         return filtered_results
 
+    async def get_references_for_events(
+        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+    ) -> Dict[str, List[_RelatedEvent]]:
+        """Get a list of references to the given events.
+
+        Args:
+            event_ids: Fetch events that relate to this event ID.
+            ignored_users: The users ignored by the requesting user.
+
+        Returns:
+            A map of event IDs to a list related events.
+        """
+
+        related_events = await self._main_store.get_references_for_events(event_ids)
+
+        # Avoid additional logic if there are no ignored users.
+        if not ignored_users:
+            return {
+                event_id: results
+                for event_id, results in related_events.items()
+                if results
+            }
+
+        # Filter out ignored users.
+        results = {}
+        for event_id, events in related_events.items():
+            # If no references, skip.
+            if not events:
+                continue
+
+            # Filter ignored users out.
+            events = [event for event in events if event.sender not in ignored_users]
+            # If there are no events left, skip this event.
+            if not events:
+                continue
+
+            results[event_id] = events
+
+        return results
+
     async def _get_threads_for_events(
         self,
         events_by_id: Dict[str, EventBase],
@@ -412,14 +409,18 @@ class RelationsHandler:
                 if event is None:
                     continue
 
-                potential_events, _ = await self.get_relations_for_event(
-                    event_id,
-                    event,
-                    room_id,
-                    RelationTypes.THREAD,
-                    ignored_users,
+                # Attempt to find another event to use as the latest event.
+                potential_events, _ = await self._main_store.get_relations_for_event(
+                    event_id, event, room_id, RelationTypes.THREAD, direction="f"
                 )
 
+                # Filter out ignored users.
+                potential_events = [
+                    event
+                    for event in potential_events
+                    if event.sender not in ignored_users
+                ]
+
                 # If all found events are from ignored users, do not include
                 # a summary of the thread.
                 if not potential_events:
@@ -534,27 +535,16 @@ class RelationsHandler:
                     "chunk": annotations
                 }
 
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            # Fetch any references to bundle with this event.
-            references, next_token = await self.get_relations_for_event(
-                event.event_id,
-                event,
-                event.room_id,
-                RelationTypes.REFERENCE,
-                ignored_users=ignored_users,
-            )
+        # Fetch any references to bundle with this event.
+        references_by_event_id = await self.get_references_for_events(
+            events_by_id.keys(), ignored_users=ignored_users
+        )
+        for event_id, references in references_by_event_id.items():
             if references:
-                aggregations = results.setdefault(event.event_id, BundledAggregations())
-                aggregations.references = {
+                results.setdefault(event_id, BundledAggregations()).references = {
                     "chunk": [{"event_id": ev.event_id} for ev in references]
                 }
 
-                if next_token:
-                    aggregations.references["next_batch"] = await next_token.to_string(
-                        self._main_store
-                    )
-
         # Fetch any edits (but not for redacted events).
         #
         # Note that there is no use in limiting edits by ignored users since the
@@ -600,7 +590,7 @@ class RelationsHandler:
             room_id, requester, allow_departed_users=True
         )
 
-        # Note that ignored users are not passed into get_relations_for_event
+        # Note that ignored users are not passed into get_threads
         # below. Ignored users are handled in filter_events_for_client (and by
         # not passing them in here we should get a better cache hit rate).
         thread_roots, next_batch = await self._main_store.get_threads(
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index ddb7397714..a58668a380 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         if relates_to:
             self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+            self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
             self._attempt_to_invalidate_cache(
                 "get_aggregation_groups_for_event", (relates_to,)
             )
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d68f127f9b..0f097a2927 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2049,6 +2049,10 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
             )
+        if rel_type == RelationTypes.REFERENCE:
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_references_for_event, (redacted_relates_to,)
+            )
         if rel_type == RelationTypes.REPLACE:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_applicable_edit, (redacted_relates_to,)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index f96a16956a..aea96e9d24 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -82,8 +82,6 @@ class _RelatedEvent:
     event_id: str
     # The sender of the related event.
     sender: str
-    topological_ordering: Optional[int]
-    stream_ordering: int
 
 
 class RelationsWorkerStore(SQLBaseStore):
@@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore):
             txn.execute(sql, where_args + [limit + 1])
 
             events = []
-            for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
+            topo_orderings: List[int] = []
+            stream_orderings: List[int] = []
+            for event_id, relation_type, sender, topo_ordering, stream_ordering in cast(
+                List[Tuple[str, str, str, int, int]], txn
+            ):
                 # Do not include edits for redacted events as they leak event
                 # content.
                 if not is_redacted or relation_type != RelationTypes.REPLACE:
-                    events.append(
-                        _RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
-                    )
+                    events.append(_RelatedEvent(event_id, sender))
+                    topo_orderings.append(topo_ordering)
+                    stream_orderings.append(stream_ordering)
 
             # If there are more events, generate the next pagination key from the
             # last event returned.
@@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore):
                 # Instead of using the last row (which tells us there is more
                 # data), use the last row to be returned.
                 events = events[:limit]
+                topo_orderings = topo_orderings[:limit]
+                stream_orderings = stream_orderings[:limit]
 
-                topo = events[-1].topological_ordering
-                token = events[-1].stream_ordering
+                topo = topo_orderings[-1]
+                token = stream_orderings[-1]
                 if direction == "b":
                     # Tokens are positions between events.
                     # This token points *after* the last event in the chunk.
@@ -531,6 +535,60 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
+    async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
+    async def get_references_for_events(
+        self, event_ids: Collection[str]
+    ) -> Mapping[str, Optional[List[_RelatedEvent]]]:
+        """Get a list of references to the given events.
+
+        Args:
+            event_ids: Fetch events that relate to these event IDs.
+
+        Returns:
+            A map of event IDs to a list of related event IDs (and their senders).
+        """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "relates_to_id", event_ids
+        )
+        args.append(RelationTypes.REFERENCE)
+
+        sql = f"""
+            SELECT relates_to_id, ref.event_id, ref.sender
+            FROM events AS ref
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = ref.room_id
+            WHERE
+                {clause}
+                AND relation_type = ?
+            ORDER BY ref.topological_ordering, ref.stream_ordering
+        """
+
+        def _get_references_for_events_txn(
+            txn: LoggingTransaction,
+        ) -> Mapping[str, List[_RelatedEvent]]:
+            txn.execute(sql, args)
+
+            result: Dict[str, List[_RelatedEvent]] = {}
+            for relates_to_id, event_id, sender in cast(
+                List[Tuple[str, str, str]], txn
+            ):
+                result.setdefault(relates_to_id, []).append(
+                    _RelatedEvent(event_id, sender)
+                )
+
+            return result
+
+        return await self.db_pool.runInteraction(
+            "_get_references_for_events_txn", _get_references_for_events_txn
+        )
+
+    @cached()
     def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
         raise NotImplementedError()
 
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 2d2b683548..b86f341ff5 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
@@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7)
 
     def test_nested_thread(self) -> None:
         """