summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8535.feature1
-rw-r--r--synapse/events/__init__.py6
-rw-r--r--synapse/events/third_party_rules.py19
-rw-r--r--synapse/handlers/federation.py66
-rw-r--r--synapse/handlers/message.py79
-rw-r--r--tests/rest/client/test_third_party_rules.py28
6 files changed, 118 insertions, 81 deletions
diff --git a/changelog.d/8535.feature b/changelog.d/8535.feature
new file mode 100644
index 0000000000..45342e66ad
--- /dev/null
+++ b/changelog.d/8535.feature
@@ -0,0 +1 @@
+Support modifying event content in `ThirdPartyRules` modules.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 7a51d0a22f..65df62107f 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -312,6 +312,12 @@ class EventBase(metaclass=abc.ABCMeta):
         """
         return [e for e, _ in self.auth_events]
 
+    def freeze(self):
+        """'Freeze' the event dict, so it cannot be modified by accident"""
+
+        # this will be a no-op if the event dict is already frozen.
+        self._dict = freeze(self._dict)
+
 
 class FrozenEvent(EventBase):
     format_version = EventFormatVersions.V1  # All events of this type are V1
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 1535cc5339..77fbd3f68a 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable
+
+from typing import Callable, Union
 
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
@@ -44,15 +45,20 @@ class ThirdPartyEventRules:
 
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
-    ) -> bool:
+    ) -> Union[bool, dict]:
         """Check if a provided event should be allowed in the given context.
 
+        The module can return:
+            * True: the event is allowed.
+            * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
+            * a dict: replacement event data.
+
         Args:
             event: The event to be checked.
             context: The context of the event.
 
         Returns:
-            True if the event should be allowed, False if not.
+            The result from the ThirdPartyRules module, as above
         """
         if self.third_party_rules is None:
             return True
@@ -63,9 +69,10 @@ class ThirdPartyEventRules:
         events = await self.store.get_events(prev_state_ids.values())
         state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
 
-        # The module can modify the event slightly if it wants, but caution should be
-        # exercised, and it's likely to go very wrong if applied to events received over
-        # federation.
+        # Ensure that the event is frozen, to make sure that the module is not tempted
+        # to try to modify it. Any attempt to modify it at this point will invalidate
+        # the hashes and signatures.
+        event.freeze()
 
         return await self.third_party_rules.check_event_allowed(event, state_events)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 455acd7669..fde8f00531 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1507,18 +1507,9 @@ class FederationHandler(BaseHandler):
             event, context = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
-        except AuthError as e:
+        except SynapseError as e:
             logger.warning("Failed to create join to %s because %s", room_id, e)
-            raise e
-
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Creation of join %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
+            raise
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
@@ -1567,15 +1558,6 @@ class FederationHandler(BaseHandler):
 
         context = await self._handle_new_event(origin, event)
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Sending of join %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
             event.event_id,
@@ -1748,15 +1730,6 @@ class FederationHandler(BaseHandler):
             builder=builder
         )
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.warning("Creation of leave %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
@@ -1789,16 +1762,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = await self._handle_new_event(origin, event)
-
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Sending of leave %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
+        await self._handle_new_event(origin, event)
 
         logger.debug(
             "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@@ -2694,18 +2658,6 @@ class FederationHandler(BaseHandler):
                 builder=builder
             )
 
-            event_allowed = await self.third_party_event_rules.check_event_allowed(
-                event, context
-            )
-            if not event_allowed:
-                logger.info(
-                    "Creation of threepid invite %s forbidden by third-party rules",
-                    event,
-                )
-                raise SynapseError(
-                    403, "This event is not allowed in this context", Codes.FORBIDDEN
-                )
-
             event, context = await self.add_display_name_to_third_party_invite(
                 room_version, event_dict, event, context
             )
@@ -2756,18 +2708,6 @@ class FederationHandler(BaseHandler):
         event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
-
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.warning(
-                "Exchange of threepid invite %s forbidden by third-party rules", event
-            )
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         event, context = await self.add_display_name_to_third_party_invite(
             room_version, event_dict, event, context
         )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f18f882596..7f00805a91 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -811,6 +811,23 @@ class EventCreationHandler:
         if requester:
             context.app_service = requester.app_service
 
+        third_party_result = await self.third_party_event_rules.check_event_allowed(
+            event, context
+        )
+        if not third_party_result:
+            logger.info(
+                "Event %s forbidden by third-party rules", event,
+            )
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
+        elif isinstance(third_party_result, dict):
+            # the third-party rules want to replace the event. We'll need to build a new
+            # event.
+            event, context = await self._rebuild_event_after_third_party_rules(
+                third_party_result, event
+            )
+
         self.validator.validate_new(event, self.config)
 
         # If this event is an annotation then we check that that the sender
@@ -897,14 +914,6 @@ class EventCreationHandler:
         else:
             room_version = await self.store.get_room_version_id(event.room_id)
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         if event.internal_metadata.is_out_of_band_membership():
             # the only sort of out-of-band-membership events we expect to see here
             # are invite rejections we have generated ourselves.
@@ -1307,3 +1316,57 @@ class EventCreationHandler:
                 room_id,
             )
             del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
+
+    async def _rebuild_event_after_third_party_rules(
+        self, third_party_result: dict, original_event: EventBase
+    ) -> Tuple[EventBase, EventContext]:
+        # the third_party_event_rules want to replace the event.
+        # we do some basic checks, and then return the replacement event and context.
+
+        # Construct a new EventBuilder and validate it, which helps with the
+        # rest of these checks.
+        try:
+            builder = self.event_builder_factory.for_room_version(
+                original_event.room_version, third_party_result
+            )
+            self.validator.validate_builder(builder)
+        except SynapseError as e:
+            raise Exception(
+                "Third party rules module created an invalid event: " + e.msg,
+            )
+
+        immutable_fields = [
+            # changing the room is going to break things: we've already checked that the
+            # room exists, and are holding a concurrency limiter token for that room.
+            # Also, we might need to use a different room version.
+            "room_id",
+            # changing the type or state key might work, but we'd need to check that the
+            # calling functions aren't making assumptions about them.
+            "type",
+            "state_key",
+        ]
+
+        for k in immutable_fields:
+            if getattr(builder, k, None) != original_event.get(k):
+                raise Exception(
+                    "Third party rules module created an invalid event: "
+                    "cannot change field " + k
+                )
+
+        # check that the new sender belongs to this HS
+        if not self.hs.is_mine_id(builder.sender):
+            raise Exception(
+                "Third party rules module created an invalid event: "
+                "invalid sender " + builder.sender
+            )
+
+        # copy over the original internal metadata
+        for k, v in original_event.internal_metadata.get_dict().items():
+            setattr(builder.internal_metadata, k, v)
+
+        event = await builder.build(prev_event_ids=original_event.prev_event_ids())
+
+        # we rebuild the event context, to be on the safe side. If nothing else,
+        # delta_ids might need an update.
+        context = await self.state.compute_event_context(event)
+        return event, context
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index b737625e33..0048bea54a 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -114,10 +114,10 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         self.render(request)
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
-    def test_modify_event(self):
-        """Tests that the module can successfully tweak an event before it is persisted.
-        """
-        # first patch the event checker so that it will modify the event
+    def test_cannot_modify_event(self):
+        """cannot accidentally modify an event before it is persisted"""
+
+        # first patch the event checker so that it will try to modify the event
         async def check(ev: EventBase, state):
             ev.content = {"x": "y"}
             return True
@@ -132,6 +132,26 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             access_token=self.tok,
         )
         self.render(request)
+        self.assertEqual(channel.result["code"], b"500", channel.result)
+
+    def test_modify_event(self):
+        """The module can return a modified version of the event"""
+        # first patch the event checker so that it will modify the event
+        async def check(ev: EventBase, state):
+            d = ev.get_dict()
+            d["content"] = {"x": "y"}
+            return d
+
+        current_rules_module().check_event_allowed = check
+
+        # now send the event
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+            {"x": "x"},
+            access_token=self.tok,
+        )
+        self.render(request)
         self.assertEqual(channel.result["code"], b"200", channel.result)
         event_id = channel.json_body["event_id"]