summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5629.bugfix1
-rw-r--r--synapse/events/__init__.py11
-rw-r--r--synapse/events/utils.py16
-rw-r--r--synapse/rest/client/v2_alpha/relations.py75
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py116
5 files changed, 180 insertions, 39 deletions
diff --git a/changelog.d/5629.bugfix b/changelog.d/5629.bugfix
new file mode 100644
index 0000000000..672eabad40
--- /dev/null
+++ b/changelog.d/5629.bugfix
@@ -0,0 +1 @@
+Forbid viewing relations on an event once it has been redacted.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index d3de70e671..88ed6d764f 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -104,6 +104,17 @@ class _EventInternalMetadata(object):
         """
         return getattr(self, "proactively_send", True)
 
+    def is_redacted(self):
+        """Whether the event has been redacted.
+
+        This is used for efficiently checking whether an event has been
+        marked as redacted without needing to make another database call.
+
+        Returns:
+            bool
+        """
+        return getattr(self, "redacted", False)
+
 
 def _event_dict_property(key):
     # We want to be able to use hasattr with the event dict properties.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 987de5cab7..9487a886f5 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -52,10 +52,15 @@ def prune_event(event):
 
     from . import event_type_from_format_version
 
-    return event_type_from_format_version(event.format_version)(
+    pruned_event = event_type_from_format_version(event.format_version)(
         pruned_event_dict, event.internal_metadata.get_dict()
     )
 
+    # Mark the event as redacted
+    pruned_event.internal_metadata.redacted = True
+
+    return pruned_event
+
 
 def prune_event_dict(event_dict):
     """Redacts the event_dict in the same way as `prune_event`, except it
@@ -360,9 +365,12 @@ class EventClientSerializer(object):
         event_id = event.event_id
         serialized_event = serialize_event(event, time_now, **kwargs)
 
-        # If MSC1849 is enabled then we need to look if thre are any relations
-        # we need to bundle in with the event
-        if self.experimental_msc1849_support_enabled and bundle_aggregations:
+        # If MSC1849 is enabled then we need to look if there are any relations
+        # we need to bundle in with the event.
+        # Do not bundle relations if the event has been redacted
+        if not event.internal_metadata.is_redacted() and (
+            self.experimental_msc1849_support_enabled and bundle_aggregations
+        ):
             annotations = yield self.store.get_aggregation_groups_for_event(event_id)
             references = yield self.store.get_relations_for_event(
                 event_id, RelationTypes.REFERENCE, direction="f"
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 7ce485b471..6e52f6d284 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -34,6 +34,7 @@ from synapse.http.servlet import (
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.storage.relations import (
     AggregationPaginationToken,
+    PaginationChunk,
     RelationPaginationToken,
 )
 
@@ -153,23 +154,28 @@ class RelationPaginationServlet(RestServlet):
         from_token = parse_string(request, "from")
         to_token = parse_string(request, "to")
 
-        if from_token:
-            from_token = RelationPaginationToken.from_string(from_token)
-
-        if to_token:
-            to_token = RelationPaginationToken.from_string(to_token)
-
-        result = yield self.store.get_relations_for_event(
-            event_id=parent_id,
-            relation_type=relation_type,
-            event_type=event_type,
-            limit=limit,
-            from_token=from_token,
-            to_token=to_token,
-        )
+        if event.internal_metadata.is_redacted():
+            # If the event is redacted, return an empty list of relations
+            pagination_chunk = PaginationChunk(chunk=[])
+        else:
+            # Return the relations
+            if from_token:
+                from_token = RelationPaginationToken.from_string(from_token)
+
+            if to_token:
+                to_token = RelationPaginationToken.from_string(to_token)
+
+            pagination_chunk = yield self.store.get_relations_for_event(
+                event_id=parent_id,
+                relation_type=relation_type,
+                event_type=event_type,
+                limit=limit,
+                from_token=from_token,
+                to_token=to_token,
+            )
 
         events = yield self.store.get_events_as_list(
-            [c["event_id"] for c in result.chunk]
+            [c["event_id"] for c in pagination_chunk.chunk]
         )
 
         now = self.clock.time_msec()
@@ -186,7 +192,7 @@ class RelationPaginationServlet(RestServlet):
             events, now, bundle_aggregations=False
         )
 
-        return_value = result.to_dict()
+        return_value = pagination_chunk.to_dict()
         return_value["chunk"] = events
         return_value["original_event"] = original_event
 
@@ -234,7 +240,7 @@ class RelationAggregationPaginationServlet(RestServlet):
 
         # This checks that a) the event exists and b) the user is allowed to
         # view it.
-        yield self.event_handler.get_event(requester.user, room_id, parent_id)
+        event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
 
         if relation_type not in (RelationTypes.ANNOTATION, None):
             raise SynapseError(400, "Relation type must be 'annotation'")
@@ -243,21 +249,26 @@ class RelationAggregationPaginationServlet(RestServlet):
         from_token = parse_string(request, "from")
         to_token = parse_string(request, "to")
 
-        if from_token:
-            from_token = AggregationPaginationToken.from_string(from_token)
-
-        if to_token:
-            to_token = AggregationPaginationToken.from_string(to_token)
-
-        res = yield self.store.get_aggregation_groups_for_event(
-            event_id=parent_id,
-            event_type=event_type,
-            limit=limit,
-            from_token=from_token,
-            to_token=to_token,
-        )
-
-        defer.returnValue((200, res.to_dict()))
+        if event.internal_metadata.is_redacted():
+            # If the event is redacted, return an empty list of relations
+            pagination_chunk = PaginationChunk(chunk=[])
+        else:
+            # Return the relations
+            if from_token:
+                from_token = AggregationPaginationToken.from_string(from_token)
+
+            if to_token:
+                to_token = AggregationPaginationToken.from_string(to_token)
+
+            pagination_chunk = yield self.store.get_aggregation_groups_for_event(
+                event_id=parent_id,
+                event_type=event_type,
+                limit=limit,
+                from_token=from_token,
+                to_token=to_token,
+            )
+
+        defer.returnValue((200, pagination_chunk.to_dict()))
 
 
 class RelationAggregationGroupPaginationServlet(RestServlet):
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 58c6951852..c7e5859970 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -93,7 +93,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     def test_deny_double_react(self):
         """Test that we deny relations on membership events
         """
-        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@@ -540,14 +540,122 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
+    def test_relations_redaction_redacts_edits(self):
+        """Test that edits of an event are redacted when the original event
+        is redacted.
+        """
+        # Send a new event
+        res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
+        original_event_id = res["event_id"]
+
+        # Add a relation
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            parent_id=original_event_id,
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Check the relation is returned
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
+            % (self.room, original_event_id),
+            access_token=self.user_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertIn("chunk", channel.json_body)
+        self.assertEquals(len(channel.json_body["chunk"]), 1)
+
+        # Redact the original event
+        request, channel = self.make_request(
+            "PUT",
+            "/rooms/%s/redact/%s/%s"
+            % (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
+            access_token=self.user_token,
+            content="{}",
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Try to check for remaining m.replace relations
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
+            % (self.room, original_event_id),
+            access_token=self.user_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Check that no relations are returned
+        self.assertIn("chunk", channel.json_body)
+        self.assertEquals(channel.json_body["chunk"], [])
+
+    def test_aggregations_redaction_prevents_access_to_aggregations(self):
+        """Test that annotations of an event are redacted when the original event
+        is redacted.
+        """
+        # Send a new event
+        res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
+        original_event_id = res["event_id"]
+
+        # Add a relation
+        channel = self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Redact the original
+        request, channel = self.make_request(
+            "PUT",
+            "/rooms/%s/redact/%s/%s"
+            % (
+                self.room,
+                original_event_id,
+                "test_aggregations_redaction_prevents_access_to_aggregations",
+            ),
+            access_token=self.user_token,
+            content="{}",
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Check that aggregations returns zero
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
+            % (self.room, original_event_id),
+            access_token=self.user_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertIn("chunk", channel.json_body)
+        self.assertEquals(channel.json_body["chunk"], [])
+
     def _send_relation(
-        self, relation_type, event_type, key=None, content={}, access_token=None
+        self,
+        relation_type,
+        event_type,
+        key=None,
+        content={},
+        access_token=None,
+        parent_id=None,
     ):
         """Helper function to send a relation pointing at `self.parent_id`
 
         Args:
             relation_type (str): One of `RelationTypes`
             event_type (str): The type of the event to create
+            parent_id (str): The event_id this relation relates to. If None, then self.parent_id
             key (str|None): The aggregation key used for m.annotation relation
                 type.
             content(dict|None): The content of the created event.
@@ -564,10 +672,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         if key:
             query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
 
+        original_id = parent_id if parent_id else self.parent_id
+
         request, channel = self.make_request(
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
-            % (self.room, self.parent_id, relation_type, event_type, query),
+            % (self.room, original_id, relation_type, event_type, query),
             json.dumps(content).encode("utf-8"),
             access_token=access_token,
         )