diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 53576ad52f..907af10995 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,7 +20,7 @@ import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
+from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
@@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event
+ async def events_have_relations(
+ self,
+ parent_ids: List[str],
+ relation_senders: Optional[List[str]],
+ relation_types: Optional[List[str]],
+ ) -> List[str]:
+ """Check which events have a relationship from the given senders of the
+ given types.
+
+ Args:
+ parent_ids: The events being annotated
+ relation_senders: The relation senders to check.
+ relation_types: The relation types to check.
+
+ Returns:
+ True if the event has at least one relationship from one of the given senders of the given type.
+ """
+ # If no restrictions are given then the event has the required relations.
+ if not relation_senders and not relation_types:
+ return parent_ids
+
+ sql = """
+ SELECT relates_to_id FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ %s;
+ """
+
+ def _get_if_event_has_relations(txn) -> List[str]:
+ clauses: List[str] = []
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "relates_to_id", parent_ids
+ )
+ clauses.append(clause)
+
+ if relation_senders:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "sender", relation_senders
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+ if relation_types:
+ clause, temp_args = make_in_list_sql_clause(
+ txn.database_engine, "relation_type", relation_types
+ )
+ clauses.append(clause)
+ args.extend(temp_args)
+
+ txn.execute(sql % " AND ".join(clauses), args)
+
+ return [row[0] for row in txn]
+
+ return await self.db_pool.runInteraction(
+ "get_if_event_has_relations", _get_if_event_has_relations
+ )
+
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:
|