diff --git a/changelog.d/5629.bugfix b/changelog.d/5629.bugfix
new file mode 100644
index 0000000000..672eabad40
--- /dev/null
+++ b/changelog.d/5629.bugfix
@@ -0,0 +1 @@
+Forbid viewing relations on an event once it has been redacted.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index d3de70e671..88ed6d764f 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -104,6 +104,17 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "proactively_send", True)
+ def is_redacted(self):
+ """Whether the event has been redacted.
+
+ This is used for efficiently checking whether an event has been
+ marked as redacted without needing to make another database call.
+
+ Returns:
+ bool
+ """
+ return getattr(self, "redacted", False)
+
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 987de5cab7..9487a886f5 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -52,10 +52,15 @@ def prune_event(event):
from . import event_type_from_format_version
- return event_type_from_format_version(event.format_version)(
+ pruned_event = event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
+ # Mark the event as redacted
+ pruned_event.internal_metadata.redacted = True
+
+ return pruned_event
+
def prune_event_dict(event_dict):
"""Redacts the event_dict in the same way as `prune_event`, except it
@@ -360,9 +365,12 @@ class EventClientSerializer(object):
event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
- # If MSC1849 is enabled then we need to look if thre are any relations
- # we need to bundle in with the event
- if self.experimental_msc1849_support_enabled and bundle_aggregations:
+ # If MSC1849 is enabled then we need to look if there are any relations
+ # we need to bundle in with the event.
+ # Do not bundle relations if the event has been redacted
+ if not event.internal_metadata.is_redacted() and (
+ self.experimental_msc1849_support_enabled and bundle_aggregations
+ ):
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 7ce485b471..6e52f6d284 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -34,6 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import (
AggregationPaginationToken,
+ PaginationChunk,
RelationPaginationToken,
)
@@ -153,23 +154,28 @@ class RelationPaginationServlet(RestServlet):
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
- if from_token:
- from_token = RelationPaginationToken.from_string(from_token)
-
- if to_token:
- to_token = RelationPaginationToken.from_string(to_token)
-
- result = yield self.store.get_relations_for_event(
- event_id=parent_id,
- relation_type=relation_type,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
+ if event.internal_metadata.is_redacted():
+ # If the event is redacted, return an empty list of relations
+ pagination_chunk = PaginationChunk(chunk=[])
+ else:
+ # Return the relations
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ pagination_chunk = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
events = yield self.store.get_events_as_list(
- [c["event_id"] for c in result.chunk]
+ [c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
@@ -186,7 +192,7 @@ class RelationPaginationServlet(RestServlet):
events, now, bundle_aggregations=False
)
- return_value = result.to_dict()
+ return_value = pagination_chunk.to_dict()
return_value["chunk"] = events
return_value["original_event"] = original_event
@@ -234,7 +240,7 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -243,21 +249,26 @@ class RelationAggregationPaginationServlet(RestServlet):
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
- if from_token:
- from_token = AggregationPaginationToken.from_string(from_token)
-
- if to_token:
- to_token = AggregationPaginationToken.from_string(to_token)
-
- res = yield self.store.get_aggregation_groups_for_event(
- event_id=parent_id,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
-
- defer.returnValue((200, res.to_dict()))
+ if event.internal_metadata.is_redacted():
+ # If the event is redacted, return an empty list of relations
+ pagination_chunk = PaginationChunk(chunk=[])
+ else:
+ # Return the relations
+ if from_token:
+ from_token = AggregationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = AggregationPaginationToken.from_string(to_token)
+
+ pagination_chunk = yield self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ defer.returnValue((200, pagination_chunk.to_dict()))
class RelationAggregationGroupPaginationServlet(RestServlet):
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 58c6951852..c7e5859970 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -93,7 +93,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def test_deny_double_react(self):
"""Test that we deny relations on membership events
"""
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -540,14 +540,122 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
+ def test_relations_redaction_redacts_edits(self):
+ """Test that edits of an event are redacted when the original event
+ is redacted.
+ """
+ # Send a new event
+ res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
+ original_event_id = res["event_id"]
+
+ # Add a relation
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ parent_id=original_event_id,
+ content={
+ "msgtype": "m.text",
+ "body": "Wibble",
+ "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+ },
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Check the relation is returned
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
+ % (self.room, original_event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertIn("chunk", channel.json_body)
+ self.assertEquals(len(channel.json_body["chunk"]), 1)
+
+ # Redact the original event
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/redact/%s/%s"
+ % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
+ access_token=self.user_token,
+ content="{}",
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Try to check for remaining m.replace relations
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
+ % (self.room, original_event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Check that no relations are returned
+ self.assertIn("chunk", channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
+ def test_aggregations_redaction_prevents_access_to_aggregations(self):
+ """Test that annotations of an event are redacted when the original event
+ is redacted.
+ """
+ # Send a new event
+ res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
+ original_event_id = res["event_id"]
+
+ # Add a relation
+ channel = self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Redact the original
+ request, channel = self.make_request(
+ "PUT",
+ "/rooms/%s/redact/%s/%s"
+ % (
+ self.room,
+ original_event_id,
+ "test_aggregations_redaction_prevents_access_to_aggregations",
+ ),
+ access_token=self.user_token,
+ content="{}",
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # Check that aggregations returns zero
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
+ % (self.room, original_event_id),
+ access_token=self.user_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertIn("chunk", channel.json_body)
+ self.assertEquals(channel.json_body["chunk"], [])
+
def _send_relation(
- self, relation_type, event_type, key=None, content={}, access_token=None
+ self,
+ relation_type,
+ event_type,
+ key=None,
+ content={},
+ access_token=None,
+ parent_id=None,
):
"""Helper function to send a relation pointing at `self.parent_id`
Args:
relation_type (str): One of `RelationTypes`
event_type (str): The type of the event to create
+ parent_id (str): The event_id this relation relates to. If None, then self.parent_id
key (str|None): The aggregation key used for m.annotation relation
type.
content(dict|None): The content of the created event.
@@ -564,10 +672,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
if key:
query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+ original_id = parent_id if parent_id else self.parent_id
+
request, channel = self.make_request(
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
- % (self.room, self.parent_id, relation_type, event_type, query),
+ % (self.room, original_id, relation_type, event_type, query),
json.dumps(content).encode("utf-8"),
access_token=access_token,
)
|