diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index f96a16956a..aea96e9d24 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -82,8 +82,6 @@ class _RelatedEvent:
event_id: str
# The sender of the related event.
sender: str
- topological_ordering: Optional[int]
- stream_ordering: int
class RelationsWorkerStore(SQLBaseStore):
@@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore):
txn.execute(sql, where_args + [limit + 1])
events = []
- for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
+ topo_orderings: List[int] = []
+ stream_orderings: List[int] = []
+ for event_id, relation_type, sender, topo_ordering, stream_ordering in cast(
+ List[Tuple[str, str, str, int, int]], txn
+ ):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or relation_type != RelationTypes.REPLACE:
- events.append(
- _RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
- )
+ events.append(_RelatedEvent(event_id, sender))
+ topo_orderings.append(topo_ordering)
+ stream_orderings.append(stream_ordering)
# If there are more events, generate the next pagination key from the
# last event returned.
@@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore):
# Instead of using the last row (which tells us there is more
# data), use the last row to be returned.
events = events[:limit]
+ topo_orderings = topo_orderings[:limit]
+ stream_orderings = stream_orderings[:limit]
- topo = events[-1].topological_ordering
- token = events[-1].stream_ordering
+ topo = topo_orderings[-1]
+ token = stream_orderings[-1]
if direction == "b":
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
@@ -531,6 +535,60 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
+ async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
+ async def get_references_for_events(
+ self, event_ids: Collection[str]
+ ) -> Mapping[str, Optional[List[_RelatedEvent]]]:
+ """Get a list of references to the given events.
+
+ Args:
+ event_ids: Fetch events that relate to these event IDs.
+
+ Returns:
+ A map of event IDs to a list of related event IDs (and their senders).
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "relates_to_id", event_ids
+ )
+ args.append(RelationTypes.REFERENCE)
+
+ sql = f"""
+ SELECT relates_to_id, ref.event_id, ref.sender
+ FROM events AS ref
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS parent ON
+ parent.event_id = relates_to_id
+ AND parent.room_id = ref.room_id
+ WHERE
+ {clause}
+ AND relation_type = ?
+ ORDER BY ref.topological_ordering, ref.stream_ordering
+ """
+
+ def _get_references_for_events_txn(
+ txn: LoggingTransaction,
+ ) -> Mapping[str, List[_RelatedEvent]]:
+ txn.execute(sql, args)
+
+ result: Dict[str, List[_RelatedEvent]] = {}
+ for relates_to_id, event_id, sender in cast(
+ List[Tuple[str, str, str]], txn
+ ):
+ result.setdefault(relates_to_id, []).append(
+ _RelatedEvent(event_id, sender)
+ )
+
+ return result
+
+ return await self.db_pool.runInteraction(
+ "_get_references_for_events_txn", _get_references_for_events_txn
+ )
+
+ @cached()
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()
|