diff --git a/changelog.d/11516.bugfix b/changelog.d/11516.bugfix
new file mode 100644
index 0000000000..22bba93671
--- /dev/null
+++ b/changelog.d/11516.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 84ef69df67..3da432c1df 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -454,23 +454,26 @@ class EventClientSerializer:
return
event_id = event.event_id
+ room_id = event.room_id
# The bundled aggregations to include.
aggregations = {}
- annotations = await self.store.get_aggregation_groups_for_event(event_id)
+ annotations = await self.store.get_aggregation_groups_for_event(
+ event_id, room_id
+ )
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
references = await self.store.get_relations_for_event(
- event_id, RelationTypes.REFERENCE, direction="f"
+ event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()
edit = None
if event.type == EventTypes.Message:
- edit = await self.store.get_applicable_edit(event_id)
+ edit = await self.store.get_applicable_edit(event_id, room_id)
if edit:
# If there is an edit replace the content, preserving existing
@@ -503,7 +506,7 @@ class EventClientSerializer:
(
thread_count,
latest_thread_event,
- ) = await self.store.get_thread_summary(event_id)
+ ) = await self.store.get_thread_summary(event_id, room_id)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c5..ffa37ef06c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
+ room_id=room_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
@@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
+ room_id=room_id,
event_type=event_type,
limit=limit,
from_token=from_token,
@@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
- await self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
result = await self.store.get_relations_for_event(
event_id=parent_id,
+ room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 4e528612ea..f1f4ce5e07 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1780,10 +1780,14 @@ class PersistEventsStore:
)
if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ txn.call_after(
+ self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
+ )
if rel_type == RelationTypes.THREAD:
- txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
+ txn.call_after(
+ self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
+ )
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0a43acda07..3368a8b084 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_relations_for_event(
self,
event_id: str,
+ room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
the form `{"event_id": "..."}`.
"""
- where_clause = ["relates_to_id = ?"]
- where_args: List[Union[str, int]] = [event_id]
+ where_clause = ["relates_to_id = ?", "room_id = ?"]
+ where_args: List[Union[str, int]] = [event_id, room_id]
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_aggregation_groups_for_event(
self,
event_id: str,
+ room_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
+ where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
+ where_args: List[Union[str, int]] = [
+ event_id,
+ room_id,
+ RelationTypes.ANNOTATION,
+ ]
if event_type:
where_clause.append("type = ?")
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
)
@cached()
- async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+ async def get_applicable_edit(
+ self, event_id: str, room_id: str
+ ) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: The original event ID
+ room_id: The original event's room ID
Returns:
The most recent edit, if any.
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
WHERE
relates_to_id = ?
AND relation_type = ?
+ AND edit.room_id = ?
AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1
"""
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
- txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone()
if row:
return row[0]
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
@cached()
async def get_thread_summary(
- self, event_id: str
+ self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event.
Args:
- event_id: The original event ID
+ event_id: Summarize the thread related to this event ID.
+ room_id: The room the event belongs to.
Returns:
The number of items in the thread and the most recent response, if any.
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone()
if row is None:
return 0, None
@@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore):
sql = """
SELECT COALESCE(COUNT(event_id), 0)
FROM event_relations
+ INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
+ AND room_id = ?
AND relation_type = ?
"""
- txn.execute(sql, (event_id, RelationTypes.THREAD))
+ txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = txn.fetchone()[0] # type: ignore[index]
return count, latest_event_id
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 397c12c2a6..55f4f0b1d0 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -16,6 +16,7 @@
import itertools
import urllib.parse
from typing import Dict, List, Optional, Tuple
+from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync
from tests import unittest
from tests.server import FakeChannel
+from tests.test_utils import make_awaitable
+from tests.test_utils.event_injection import inject_event
class RelationsTestCase(unittest.HomeserverTestCase):
@@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase):
},
)
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_ignore_invalid_room(self):
+ """Test that we ignore invalid relations over federation."""
+ # Create another room and send a message in it.
+ room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
+ res = self.helper.send(room2, body="Hi!", tok=self.user_token)
+ parent_id = res["event_id"]
+
+ # Disable the validation to pretend this came over federation.
+ with patch(
+ "synapse.handlers.message.EventCreationHandler._validate_event_relation",
+ new=lambda self, event: make_awaitable(None),
+ ):
+ # Generate a various relations from a different room.
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.reaction",
+ sender=self.user_id,
+ content={
+ "m.relates_to": {
+ "rel_type": RelationTypes.ANNOTATION,
+ "event_id": parent_id,
+ "key": "A",
+ }
+ },
+ )
+ )
+
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.room.message",
+ sender=self.user_id,
+ content={
+ "body": "foo",
+ "msgtype": "m.text",
+ "m.relates_to": {
+ "rel_type": RelationTypes.REFERENCE,
+ "event_id": parent_id,
+ },
+ },
+ )
+ )
+
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.room.message",
+ sender=self.user_id,
+ content={
+ "body": "foo",
+ "msgtype": "m.text",
+ "m.relates_to": {
+ "rel_type": RelationTypes.THREAD,
+ "event_id": parent_id,
+ },
+ },
+ )
+ )
+
+ self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room,
+ type="m.room.message",
+ sender=self.user_id,
+ content={
+ "body": "foo",
+ "msgtype": "m.text",
+ "new_content": {
+ "body": "new content",
+ "msgtype": "m.text",
+ },
+ "m.relates_to": {
+ "rel_type": RelationTypes.REPLACE,
+ "event_id": parent_id,
+ },
+ },
+ )
+ )
+
+ # They should be ignored when fetching relations.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(channel.json_body["chunk"], [])
+
+ # And when fetching aggregations.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertEqual(channel.json_body["chunk"], [])
+
+ # And for bundled aggregations.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{room2}/event/{parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ self.assertNotIn("m.relations", channel.json_body["unsigned"])
+
def test_edit(self):
"""Test that a simple edit works."""
|