summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14260.feature1
-rw-r--r--synapse/api/constants.py2
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/message.py47
-rw-r--r--synapse/handlers/relations.py56
-rw-r--r--synapse/rest/client/room.py57
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/storage/databases/main/relations.py36
-rw-r--r--tests/rest/client/test_redactions.py273
-rw-r--r--tests/rest/client/utils.py37
10 files changed, 486 insertions, 28 deletions
diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature
new file mode 100644
index 0000000000..102dc7b3e0
--- /dev/null
+++ b/changelog.d/14260.feature
@@ -0,0 +1 @@
+Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 44c5ffc6a5..bc04a0755b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -125,6 +125,8 @@ class EventTypes:
     MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
     MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
 
+    Reaction: Final = "m.reaction"
+
 
 class ToDeviceEventTypes:
     RoomKeyRequest: Final = "m.room_key_request"
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index d9bdd66d55..d4b71d1673 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -128,3 +128,6 @@ class ExperimentalConfig(Config):
         self.msc3886_endpoint: Optional[str] = experimental.get(
             "msc3886_endpoint", None
         )
+
+        # MSC3912: Relation-based redactions.
+        self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 468900a07f..4cf593cfdc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -877,6 +877,36 @@ class EventCreationHandler:
                 return prev_event
         return None
 
+    async def get_event_from_transaction(
+        self,
+        requester: Requester,
+        txn_id: str,
+        room_id: str,
+    ) -> Optional[EventBase]:
+        """For the given transaction ID and room ID, check if there is a matching event.
+        If so, fetch it and return it.
+
+        Args:
+            requester: The requester making the request in the context of which we want
+                to fetch the event.
+            txn_id: The transaction ID.
+            room_id: The room ID.
+
+        Returns:
+            An event if one could be found, None otherwise.
+        """
+        if requester.access_token_id:
+            existing_event_id = await self.store.get_event_id_from_transaction_id(
+                room_id,
+                requester.user.to_string(),
+                requester.access_token_id,
+                txn_id,
+            )
+            if existing_event_id:
+                return await self.store.get_event(existing_event_id)
+
+        return None
+
     async def create_and_send_nonmember_event(
         self,
         requester: Requester,
@@ -956,18 +986,17 @@ class EventCreationHandler:
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
         async with self.limiter.queue(event_dict["room_id"]):
-            if txn_id and requester.access_token_id:
-                existing_event_id = await self.store.get_event_id_from_transaction_id(
-                    event_dict["room_id"],
-                    requester.user.to_string(),
-                    requester.access_token_id,
-                    txn_id,
+            if txn_id:
+                event = await self.get_event_from_transaction(
+                    requester, txn_id, event_dict["room_id"]
                 )
-                if existing_event_id:
-                    event = await self.store.get_event(existing_event_id)
+                if event:
                     # we know it was persisted, so must have a stream ordering
                     assert event.internal_metadata.stream_ordering
-                    return event, event.internal_metadata.stream_ordering
+                    return (
+                        event,
+                        event.internal_metadata.stream_ordering,
+                    )
 
             event, context = await self.create_event(
                 requester,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0a0c6d938e..8e71dda970 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup
 
 import attr
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.opentracing import trace
@@ -75,6 +75,7 @@ class RelationsHandler:
         self._clock = hs.get_clock()
         self._event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
+        self._event_creation_handler = hs.get_event_creation_handler()
 
     async def get_relations(
         self,
@@ -205,6 +206,59 @@ class RelationsHandler:
 
         return related_events, next_token
 
+    async def redact_events_related_to(
+        self,
+        requester: Requester,
+        event_id: str,
+        initial_redaction_event: EventBase,
+        relation_types: List[str],
+    ) -> None:
+        """Redacts all events related to the given event ID with one of the given
+        relation types.
+
+        This method is expected to be called when redacting the event referred to by
+        the given event ID.
+
+        If an event cannot be redacted (e.g. because of insufficient permissions), log
+        the error and try to redact the next one.
+
+        Args:
+            requester: The requester to redact events on behalf of.
+            event_id: The event IDs to look and redact relations of.
+            initial_redaction_event: The redaction for the event referred to by
+                event_id.
+            relation_types: The types of relations to look for.
+
+        Raises:
+            ShadowBanError if the requester is shadow-banned
+        """
+        related_event_ids = (
+            await self._main_store.get_all_relations_for_event_with_types(
+                event_id, relation_types
+            )
+        )
+
+        for related_event_id in related_event_ids:
+            try:
+                await self._event_creation_handler.create_and_send_nonmember_event(
+                    requester,
+                    {
+                        "type": EventTypes.Redaction,
+                        "content": initial_redaction_event.content,
+                        "room_id": initial_redaction_event.room_id,
+                        "sender": requester.user.to_string(),
+                        "redacts": related_event_id,
+                    },
+                    ratelimit=False,
+                )
+            except SynapseError as e:
+                logger.warning(
+                    "Failed to redact event %s (related to event %s): %s",
+                    related_event_id,
+                    event_id,
+                    e.msg,
+                )
+
     async def get_annotations_for_event(
         self,
         event_id: str,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 01e5079963..91cb791139 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -52,6 +52,7 @@ from synapse.http.servlet import (
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client._base import client_patterns
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.storage.state import StateFilter
@@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         super().__init__(hs)
         self.event_creation_handler = hs.get_event_creation_handler()
         self.auth = hs.get_auth()
+        self._relation_handler = hs.get_relations_handler()
+        self._msc3912_enabled = hs.config.experimental.msc3912_enabled
 
     def register(self, http_server: HttpServer) -> None:
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
@@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         content = parse_json_object_from_request(request)
 
         try:
-            (
-                event,
-                _,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester,
-                {
-                    "type": EventTypes.Redaction,
-                    "content": content,
-                    "room_id": room_id,
-                    "sender": requester.user.to_string(),
-                    "redacts": event_id,
-                },
-                txn_id=txn_id,
-            )
+            with_relations = None
+            if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content:
+                with_relations = content["org.matrix.msc3912.with_relations"]
+                del content["org.matrix.msc3912.with_relations"]
+
+            # Check if there's an existing event for this transaction now (even though
+            # create_and_send_nonmember_event also does it) because, if there's one,
+            # then we want to skip the call to redact_events_related_to.
+            event = None
+            if txn_id:
+                event = await self.event_creation_handler.get_event_from_transaction(
+                    requester, txn_id, room_id
+                )
+
+            if event is None:
+                (
+                    event,
+                    _,
+                ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                    requester,
+                    {
+                        "type": EventTypes.Redaction,
+                        "content": content,
+                        "room_id": room_id,
+                        "sender": requester.user.to_string(),
+                        "redacts": event_id,
+                    },
+                    txn_id=txn_id,
+                )
+
+                if with_relations:
+                    run_as_background_process(
+                        "redact_related_events",
+                        self._relation_handler.redact_events_related_to,
+                        requester=requester,
+                        event_id=event_id,
+                        initial_redaction_event=event,
+                        relation_types=with_relations,
+                    )
+
             event_id = event.event_id
         except ShadowBanError:
             event_id = "$" + random_string(43)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 9b1b72c68a..180a11ef88 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet):
                     # Adds support for simple HTTP rendezvous as per MSC3886
                     "org.matrix.msc3886": self.config.experimental.msc3886_endpoint
                     is not None,
+                    # Adds support for relation-based redactions as per MSC3912.
+                    "org.matrix.msc3912": self.config.experimental.msc3912_enabled,
                 },
             },
         )
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c022510e76..ca431002c8 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
+    async def get_all_relations_for_event_with_types(
+        self,
+        event_id: str,
+        relation_types: List[str],
+    ) -> List[str]:
+        """Get the event IDs of all events that have a relation to the given event with
+        one of the given relation types.
+
+        Args:
+            event_id: The event for which to look for related events.
+            relation_types: The types of relations to look for.
+
+        Returns:
+            A list of the IDs of the events that relate to the given event with one of
+            the given relation types.
+        """
+
+        def get_all_relation_ids_for_event_with_types_txn(
+            txn: LoggingTransaction,
+        ) -> List[str]:
+            rows = self.db_pool.simple_select_many_txn(
+                txn=txn,
+                table="event_relations",
+                column="relation_type",
+                iterable=relation_types,
+                keyvalues={"relates_to_id": event_id},
+                retcols=["event_id"],
+            )
+
+            return [row["event_id"] for row in rows]
+
+        return await self.db_pool.runInteraction(
+            desc="get_all_relation_ids_for_event_with_types",
+            func=get_all_relation_ids_for_event_with_types_txn,
+        )
+
     async def event_includes_relation(self, event_id: str) -> bool:
         """Check if the given event relates to another event.
 
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index be4c67d68e..5dfe44defb 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -11,17 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 
 class RedactionsTestCase(HomeserverTestCase):
@@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase):
         )
 
     def _redact_event(
-        self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
+        self,
+        access_token: str,
+        room_id: str,
+        event_id: str,
+        expect_code: int = 200,
+        with_relations: Optional[List[str]] = None,
     ) -> JsonDict:
         """Helper function to send a redaction event.
 
@@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase):
         """
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
 
-        channel = self.make_request("POST", path, content={}, access_token=access_token)
+        request_content = {}
+        if with_relations:
+            request_content["org.matrix.msc3912.with_relations"] = with_relations
+
+        channel = self.make_request(
+            "POST", path, request_content, access_token=access_token
+        )
         self.assertEqual(channel.code, expect_code)
         return channel.json_body
 
@@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase):
             # These should all succeed, even though this would be denied by
             # the standard message ratelimiter
             self._redact_event(self.mod_access_token, self.room_id, msg_id)
+
+    @override_config({"experimental_features": {"msc3912_enabled": True}})
+    def test_redact_relations(self) -> None:
+        """Tests that we can redact the relations of an event at the same time as the
+        event itself.
+        """
+        # Send a root event.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "hello"},
+            tok=self.mod_access_token,
+        )
+        root_event_id = res["event_id"]
+
+        # Send an edit to this root event.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": " * hello world",
+                "m.new_content": {
+                    "body": "hello world",
+                    "msgtype": "m.text",
+                },
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.REPLACE,
+                },
+                "msgtype": "m.text",
+            },
+            tok=self.mod_access_token,
+        )
+        edit_event_id = res["event_id"]
+
+        # Also send a threaded message whose root is the same as the edit's.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "message 1",
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.THREAD,
+                },
+            },
+            tok=self.mod_access_token,
+        )
+        threaded_event_id = res["event_id"]
+
+        # Also send a reaction, again with the same root.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Reaction,
+            content={
+                "m.relates_to": {
+                    "rel_type": RelationTypes.ANNOTATION,
+                    "event_id": root_event_id,
+                    "key": "👍",
+                }
+            },
+            tok=self.mod_access_token,
+        )
+        reaction_event_id = res["event_id"]
+
+        # Redact the root event, specifying that we also want to delete events that
+        # relate to it with m.replace.
+        self._redact_event(
+            self.mod_access_token,
+            self.room_id,
+            root_event_id,
+            with_relations=[
+                RelationTypes.REPLACE,
+                RelationTypes.THREAD,
+            ],
+        )
+
+        # Check that the root event got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, root_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the edit got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, edit_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the threaded message got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, threaded_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the reaction did not get redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, reaction_event_id, self.mod_access_token
+        )
+        self.assertNotIn("redacted_because", event_dict, event_dict)
+
+    @override_config({"experimental_features": {"msc3912_enabled": True}})
+    def test_redact_relations_no_perms(self) -> None:
+        """Tests that, when redacting a message along with its relations, if not all
+        the related messages can be redacted because of insufficient permissions, the
+        server still redacts all the ones that can be.
+        """
+        # Send a root event.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "root",
+            },
+            tok=self.other_access_token,
+        )
+        root_event_id = res["event_id"]
+
+        # Send a first threaded message, this one from the moderator. We do this for the
+        # first message with the m.thread relation (and not the last one) to ensure
+        # that, when the server fails to redact it, it doesn't stop there, and it
+        # instead goes on to redact the other one.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "message 1",
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.THREAD,
+                },
+            },
+            tok=self.mod_access_token,
+        )
+        first_threaded_event_id = res["event_id"]
+
+        # Send a second threaded message, this time from the user who'll perform the
+        # redaction.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "message 2",
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.THREAD,
+                },
+            },
+            tok=self.other_access_token,
+        )
+        second_threaded_event_id = res["event_id"]
+
+        # Redact the thread's root, and request that all threaded messages are also
+        # redacted. Send that request from the non-mod user, so that the first threaded
+        # event cannot be redacted.
+        self._redact_event(
+            self.other_access_token,
+            self.room_id,
+            root_event_id,
+            with_relations=[RelationTypes.THREAD],
+        )
+
+        # Check that the thread root got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, root_event_id, self.other_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the last message in the thread got redacted, despite failing to
+        # redact the one before it.
+        event_dict = self.helper.get_event(
+            self.room_id, second_threaded_event_id, self.other_access_token
+        )
+        self.assertIn("redacted_because", event_dict, event_dict)
+
+        # Check that the message that was sent into the tread by the mod user is not
+        # redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, first_threaded_event_id, self.other_access_token
+        )
+        self.assertIn("body", event_dict["content"], event_dict)
+        self.assertEqual("message 1", event_dict["content"]["body"])
+
+    @override_config({"experimental_features": {"msc3912_enabled": True}})
+    def test_redact_relations_txn_id_reuse(self) -> None:
+        """Tests that redacting a message using a transaction ID, then reusing the same
+        transaction ID but providing an additional list of relations to redact, is
+        effectively a no-op.
+        """
+        # Send a root event.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "root",
+            },
+            tok=self.mod_access_token,
+        )
+        root_event_id = res["event_id"]
+
+        # Send a first threaded message.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "I'm in a thread!",
+                "m.relates_to": {
+                    "event_id": root_event_id,
+                    "rel_type": RelationTypes.THREAD,
+                },
+            },
+            tok=self.mod_access_token,
+        )
+        threaded_event_id = res["event_id"]
+
+        # Send a first redaction request which redacts only the root event.
+        channel = self.make_request(
+            method="PUT",
+            path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
+            content={},
+            access_token=self.mod_access_token,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Send a second redaction request which redacts the root event as well as
+        # threaded messages.
+        channel = self.make_request(
+            method="PUT",
+            path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
+            content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]},
+            access_token=self.mod_access_token,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Check that the root event got redacted.
+        event_dict = self.helper.get_event(
+            self.room_id, root_event_id, self.mod_access_token
+        )
+        self.assertIn("redacted_because", event_dict)
+
+        # Check that the threaded message didn't get redacted (since that wasn't part of
+        # the original redaction).
+        event_dict = self.helper.get_event(
+            self.room_id, threaded_event_id, self.mod_access_token
+        )
+        self.assertIn("body", event_dict["content"], event_dict)
+        self.assertEqual("I'm in a thread!", event_dict["content"]["body"])
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 706399fae5..8d6f2b6ff9 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -410,6 +410,43 @@ class RestHelper:
 
         return channel.json_body
 
+    def get_event(
+        self,
+        room_id: str,
+        event_id: str,
+        tok: Optional[str] = None,
+        expect_code: int = HTTPStatus.OK,
+    ) -> JsonDict:
+        """Request a specific event from the server.
+
+        Args:
+            room_id: the room in which the event was sent.
+            event_id: the event's ID.
+            tok: the token to request the event with.
+            expect_code: the expected HTTP status for the response.
+
+        Returns:
+            The event as a dict.
+        """
+        path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}"
+        if tok:
+            path = path + f"?access_token={tok}"
+
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            path,
+        )
+
+        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            channel.code,
+            channel.result["body"],
+        )
+
+        return channel.json_body
+
     def _read_write_state(
         self,
         room_id: str,