diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
+ async def event_includes_relation(self, event_id: str) -> bool:
+ """Check if the given event relates to another event.
+
+ An event has a relation if it has a valid m.relates_to with a rel_type
+ and event_id in the content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$other_event_id"
+ }
+ }
+ }
+
+ Args:
+ event_id: The event to check.
+
+ Returns:
+ True if the event includes a valid relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"event_id": event_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_includes_relation",
+ )
+ return result is not None
+
+ async def event_is_target_of_relation(self, parent_id: str) -> bool:
+ """Check if the given event is the target of another event's relation.
+
+ An event is the target of an event relation if it has a valid
+ m.relates_to with a rel_type and event_id pointing to parent_id in the
+ content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$parent_id"
+ }
+ }
+ }
+
+ Args:
+ parent_id: The event to check.
+
+ Returns:
+ True if the event is the target of another event's relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"relates_to_id": parent_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_is_target_of_relation",
+ )
+ return result is not None
+
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_event_has_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_if_event_has_relations", _get_if_event_has_relations
+ "get_if_events_have_relations", _get_if_events_have_relations
)
async def has_user_annotated_event(
|