diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d4c2a6ab7a..22dd4cf5fd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1001,13 +1001,52 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
+ await self._validate_event_relation(event)
+ logger.debug("Created event %s", event.event_id)
+
+ return event, context
+
+ async def _validate_event_relation(self, event: EventBase) -> None:
+ """
+ Ensure the relation data on a new event is not bogus.
+
+ Args:
+ event: The event being created.
+
+ Raises:
+ SynapseError if the event is invalid.
+ """
+
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ return
+
+ relation_type = relation.get("rel_type")
+ if not relation_type:
+ return
+
+ # Ensure the parent is real.
+ relates_to = relation.get("event_id")
+ if not relates_to:
+ return
+
+ parent_event = await self.store.get_event(relates_to, allow_none=True)
+ if parent_event:
+ # And in the same room.
+ if parent_event.room_id != event.room_id:
+ raise SynapseError(400, "Relations must be in the same room")
+
+ else:
+ # There must be some reason that the client knows the event exists,
+ # see if there are existing relations. If so, assume everything is fine.
+ if not await self.store.event_is_target_of_relation(relates_to):
+ # Otherwise, the client can't know about the parent event!
+ raise SynapseError(400, "Can't send relation to unknown event")
# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
# event multiple times).
- relation = event.content.get("m.relates_to", {})
- if relation.get("rel_type") == RelationTypes.ANNOTATION:
- relates_to = relation["event_id"]
+ if relation_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"]
already_exists = await self.store.has_user_annotated_event(
@@ -1016,9 +1055,12 @@ class EventCreationHandler:
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
- logger.debug("Created event %s", event.event_id)
-
- return event, context
+ # Don't attempt to start a thread if the parent event is a relation.
+ elif relation_type == RelationTypes.THREAD:
+ if await self.store.event_includes_relation(relates_to):
+ raise SynapseError(
+ 400, "Cannot start threads from an event with a relation"
+ )
@measure_func("handle_new_client_event")
async def handle_new_client_event(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 907af10995..0a43acda07 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -132,6 +132,69 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
+ async def event_includes_relation(self, event_id: str) -> bool:
+ """Check if the given event relates to another event.
+
+ An event has a relation if it has a valid m.relates_to with a rel_type
+ and event_id in the content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$other_event_id"
+ }
+ }
+ }
+
+ Args:
+ event_id: The event to check.
+
+ Returns:
+ True if the event includes a valid relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"event_id": event_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_includes_relation",
+ )
+ return result is not None
+
+ async def event_is_target_of_relation(self, parent_id: str) -> bool:
+ """Check if the given event is the target of another event's relation.
+
+ An event is the target of an event relation if it has a valid
+ m.relates_to with a rel_type and event_id pointing to parent_id in the
+ content:
+
+ {
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": "$parent_id"
+ }
+ }
+ }
+
+ Args:
+ parent_id: The event to check.
+
+ Returns:
+ True if the event is the target of another event's relation.
+ """
+
+ result = await self.db_pool.simple_select_one_onecol(
+ table="event_relations",
+ keyvalues={"relates_to_id": parent_id},
+ retcol="event_id",
+ allow_none=True,
+ desc="event_is_target_of_relation",
+ )
+ return result is not None
+
@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
@@ -362,7 +425,7 @@ class RelationsWorkerStore(SQLBaseStore):
%s;
"""
- def _get_if_event_has_relations(txn) -> List[str]:
+ def _get_if_events_have_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
@@ -387,7 +450,7 @@ class RelationsWorkerStore(SQLBaseStore):
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_if_event_has_relations", _get_if_event_has_relations
+ "get_if_events_have_relations", _get_if_events_have_relations
)
async def has_user_annotated_event(
|