summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/relations.py74
3 files changed, 71 insertions, 8 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index ddb7397714..a58668a380 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         if relates_to:
             self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+            self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
             self._attempt_to_invalidate_cache(
                 "get_aggregation_groups_for_event", (relates_to,)
             )
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index d68f127f9b..0f097a2927 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -2049,6 +2049,10 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
             )
+        if rel_type == RelationTypes.REFERENCE:
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_references_for_event, (redacted_relates_to,)
+            )
         if rel_type == RelationTypes.REPLACE:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_applicable_edit, (redacted_relates_to,)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index f96a16956a..aea96e9d24 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -82,8 +82,6 @@ class _RelatedEvent:
     event_id: str
     # The sender of the related event.
     sender: str
-    topological_ordering: Optional[int]
-    stream_ordering: int
 
 
 class RelationsWorkerStore(SQLBaseStore):
@@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore):
             txn.execute(sql, where_args + [limit + 1])
 
             events = []
-            for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
+            topo_orderings: List[int] = []
+            stream_orderings: List[int] = []
+            for event_id, relation_type, sender, topo_ordering, stream_ordering in cast(
+                List[Tuple[str, str, str, int, int]], txn
+            ):
                 # Do not include edits for redacted events as they leak event
                 # content.
                 if not is_redacted or relation_type != RelationTypes.REPLACE:
-                    events.append(
-                        _RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
-                    )
+                    events.append(_RelatedEvent(event_id, sender))
+                    topo_orderings.append(topo_ordering)
+                    stream_orderings.append(stream_ordering)
 
             # If there are more events, generate the next pagination key from the
             # last event returned.
@@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore):
                 # Instead of using the last row (which tells us there is more
                 # data), use the last row to be returned.
                 events = events[:limit]
+                topo_orderings = topo_orderings[:limit]
+                stream_orderings = stream_orderings[:limit]
 
-                topo = events[-1].topological_ordering
-                token = events[-1].stream_ordering
+                topo = topo_orderings[-1]
+                token = stream_orderings[-1]
                 if direction == "b":
                     # Tokens are positions between events.
                     # This token points *after* the last event in the chunk.
@@ -531,6 +535,60 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
+    async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
+    async def get_references_for_events(
+        self, event_ids: Collection[str]
+    ) -> Mapping[str, Optional[List[_RelatedEvent]]]:
+        """Get a list of references to the given events.
+
+        Args:
+            event_ids: Fetch events that relate to these event IDs.
+
+        Returns:
+            A map of event IDs to a list of related event IDs (and their senders).
+        """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "relates_to_id", event_ids
+        )
+        args.append(RelationTypes.REFERENCE)
+
+        sql = f"""
+            SELECT relates_to_id, ref.event_id, ref.sender
+            FROM events AS ref
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = ref.room_id
+            WHERE
+                {clause}
+                AND relation_type = ?
+            ORDER BY ref.topological_ordering, ref.stream_ordering
+        """
+
+        def _get_references_for_events_txn(
+            txn: LoggingTransaction,
+        ) -> Mapping[str, List[_RelatedEvent]]:
+            txn.execute(sql, args)
+
+            result: Dict[str, List[_RelatedEvent]] = {}
+            for relates_to_id, event_id, sender in cast(
+                List[Tuple[str, str, str]], txn
+            ):
+                result.setdefault(relates_to_id, []).append(
+                    _RelatedEvent(event_id, sender)
+                )
+
+            return result
+
+        return await self.db_pool.runInteraction(
+            "_get_references_for_events_txn", _get_references_for_events_txn
+        )
+
+    @cached()
     def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
         raise NotImplementedError()