summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11126.feature1
-rw-r--r--docs/modules/third_party_rules_callbacks.md21
-rw-r--r--synapse/events/third_party_rules.py31
-rw-r--r--synapse/handlers/federation_event.py2
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/notifier.py17
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--tests/rest/client/test_third_party_rules.py93
8 files changed, 165 insertions, 12 deletions
diff --git a/changelog.d/11126.feature b/changelog.d/11126.feature
new file mode 100644
index 0000000000..c6078fe081
--- /dev/null
+++ b/changelog.d/11126.feature
@@ -0,0 +1 @@
+Add an `on_new_event` third-party rules callback to allow Synapse modules to act after an event has been sent into a room.
diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md
index 034923da0f..a16e272f79 100644
--- a/docs/modules/third_party_rules_callbacks.md
+++ b/docs/modules/third_party_rules_callbacks.md
@@ -119,6 +119,27 @@ callback returns `True`, Synapse falls through to the next one. The value of the
 callback that does not return `True` will be used. If this happens, Synapse will not call
 any of the subsequent implementations of this callback.
 
+### `on_new_event`
+
+_First introduced in Synapse v1.47.0_
+
+```python
+async def on_new_event(
+    event: "synapse.events.EventBase",
+    state_events: "synapse.types.StateMap",
+) -> None:
+```
+
+Called after sending an event into a room. The module is passed the event, as well
+as the state of the room _after_ the event. This means that if the event is a state event,
+it will be included in this state.
+
+Note that this callback is called when the event has already been processed and stored
+into the room, which means this callback cannot be used to deny persisting the event. To
+deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#check_event_for_spam) instead.
+
+If multiple modules implement this callback, Synapse runs them all in order.
+
 ## Example
 
 The example below is a module that implements the third-party rules callback
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 2a6dabdab6..8816ef4b76 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -36,6 +36,7 @@ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
 CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
     [str, StateMap[EventBase], str], Awaitable[bool]
 ]
+ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
 
 
 def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -152,6 +153,7 @@ class ThirdPartyEventRules:
         self._check_visibility_can_be_modified_callbacks: List[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = []
+        self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
 
     def register_third_party_rules_callbacks(
         self,
@@ -163,6 +165,7 @@ class ThirdPartyEventRules:
         check_visibility_can_be_modified: Optional[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
+        on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
     ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
@@ -181,6 +184,9 @@ class ThirdPartyEventRules:
                 check_visibility_can_be_modified,
             )
 
+        if on_new_event is not None:
+            self._on_new_event_callbacks.append(on_new_event)
+
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
     ) -> Tuple[bool, Optional[dict]]:
@@ -321,6 +327,31 @@ class ThirdPartyEventRules:
 
         return True
 
+    async def on_new_event(self, event_id: str) -> None:
+        """Let modules act on events after they've been sent (e.g. auto-accepting
+        invites, etc.)
+
+        Args:
+            event_id: The ID of the event.
+
+        Raises:
+            ModuleFailureError if a callback raised any exception.
+        """
+        # Bail out early without hitting the store if we don't have any callbacks
+        if len(self._on_new_event_callbacks) == 0:
+            return
+
+        event = await self.store.get_event(event_id)
+        state_events = await self._get_state_map_for_room(event.room_id)
+
+        for callback in self._on_new_event_callbacks:
+            try:
+                await callback(event, state_events)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
+
     async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
         """Given a room ID, return the state events of that room.
 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9584d5bd46..bd1fa08cef 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1916,7 +1916,7 @@ class FederationEventHandler:
         event_pos = PersistedEventPosition(
             self._instance_name, event.internal_metadata.stream_ordering
         )
-        self._notifier.on_new_room_event(
+        await self._notifier.on_new_room_event(
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2e024b551f..4a0fccfcc6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1537,13 +1537,16 @@ class EventCreationHandler:
             # If there's an expiry timestamp on the event, schedule its expiry.
             self._message_handler.maybe_schedule_expiry(event)
 
-        def _notify() -> None:
+        async def _notify() -> None:
             try:
-                self.notifier.on_new_room_event(
+                await self.notifier.on_new_room_event(
                     event, event_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
-                logger.exception("Error notifying about new room event")
+                logger.exception(
+                    "Error notifying about new room event %s",
+                    event.event_id,
+                )
 
         run_in_background(_notify)
 
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1acd899fab..1882fffd2a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -220,6 +220,8 @@ class Notifier:
         # down.
         self.remote_server_up_callbacks: List[Callable[[str], None]] = []
 
+        self._third_party_rules = hs.get_third_party_event_rules()
+
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
         self._pusher_pool = hs.get_pusherpool()
@@ -267,7 +269,7 @@ class Notifier:
         """
         self.replication_callbacks.append(cb)
 
-    def on_new_room_event(
+    async def on_new_room_event(
         self,
         event: EventBase,
         event_pos: PersistedEventPosition,
@@ -275,9 +277,10 @@ class Notifier:
         extra_users: Optional[Collection[UserID]] = None,
     ):
         """Unwraps event and calls `on_new_room_event_args`."""
-        self.on_new_room_event_args(
+        await self.on_new_room_event_args(
             event_pos=event_pos,
             room_id=event.room_id,
+            event_id=event.event_id,
             event_type=event.type,
             state_key=event.get("state_key"),
             membership=event.content.get("membership"),
@@ -285,9 +288,10 @@ class Notifier:
             extra_users=extra_users or [],
         )
 
-    def on_new_room_event_args(
+    async def on_new_room_event_args(
         self,
         room_id: str,
+        event_id: str,
         event_type: str,
         state_key: Optional[str],
         membership: Optional[str],
@@ -302,7 +306,10 @@ class Notifier:
         listening to the room, and any listeners for the users in the
         `extra_users` param.
 
-        The events can be peristed out of order. The notifier will wait
+        This also notifies modules listening on new events via the
+        `on_new_event` callback.
+
+        The events can be persisted out of order. The notifier will wait
         until all previous events have been persisted before notifying
         the client streams.
         """
@@ -318,6 +325,8 @@ class Notifier:
         )
         self._notify_pending_new_room_events(max_room_stream_token)
 
+        await self._third_party_rules.on_new_event(event_id)
+
         self.notify_replication()
 
     def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 961c17762e..e29ae1e375 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -207,11 +207,12 @@ class ReplicationDataHandler:
 
                 max_token = self.store.get_room_max_token()
                 event_pos = PersistedEventPosition(instance_name, token)
-                self.notifier.on_new_room_event_args(
+                await self.notifier.on_new_room_event_args(
                     event_pos=event_pos,
                     max_room_stream_token=max_token,
                     extra_users=extra_users,
                     room_id=row.data.room_id,
+                    event_id=row.data.event_id,
                     event_type=row.data.type,
                     state_key=row.data.state_key,
                     membership=row.data.membership,
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 531f09c48b..1c42c46630 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,7 +15,7 @@ import threading
 from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from unittest.mock import Mock
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -25,6 +25,7 @@ from synapse.types import JsonDict, Requester, StateMap
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 if TYPE_CHECKING:
     from synapse.module_api import ModuleApi
@@ -74,7 +75,7 @@ class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
         return d
 
 
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -86,11 +87,29 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
 
         load_legacy_third_party_event_rules(hs)
 
+        # We're not going to be properly signing events as our remote homeserver is fake,
+        # therefore disable event signature checks.
+        # Note that these checks are not relevant to this test case.
+
+        # Have this homeserver auto-approve all event signature checking.
+        async def approve_all_signature_checking(_, pdu):
+            return pdu
+
+        hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
+
+        # Have this homeserver skip event auth checks. This is necessary due to
+        # event auth checks ensuring that events were signed by the sender's homeserver.
+        async def _check_event_auth(origin, event, context, *args, **kwargs):
+            return context
+
+        hs.get_federation_event_handler()._check_event_auth = _check_event_auth
+
         return hs
 
     def prepare(self, reactor, clock, homeserver):
-        # Create a user and room to play with during the tests
+        # Create some users and a room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
+        self.invitee = self.register_user("invitee", "hackme")
         self.tok = self.login("kermit", "monkey")
 
         # Some tests might prevent room creation on purpose.
@@ -424,6 +443,74 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             self.assertEqual(channel.code, 200)
             self.assertEqual(channel.json_body["i"], i)
 
+    def test_on_new_event(self):
+        """Test that the on_new_event callback is called on new events"""
+        on_new_event = Mock(make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
+            on_new_event
+        )
+
+        # Send a message event to the room and check that the callback is called.
+        self.helper.send(room_id=self.room_id, tok=self.tok)
+        self.assertEqual(on_new_event.call_count, 1)
+
+        # Check that the callback is also called on membership updates.
+        self.helper.invite(
+            room=self.room_id,
+            src=self.user_id,
+            targ=self.invitee,
+            tok=self.tok,
+        )
+
+        self.assertEqual(on_new_event.call_count, 2)
+
+        args, _ = on_new_event.call_args
+
+        self.assertEqual(args[0].membership, Membership.INVITE)
+        self.assertEqual(args[0].state_key, self.invitee)
+
+        # Check that the invitee's membership is correct in the state that's passed down
+        # to the callback.
+        self.assertEqual(
+            args[1][(EventTypes.Member, self.invitee)].membership,
+            Membership.INVITE,
+        )
+
+        # Send an event over federation and check that the callback is also called.
+        self._send_event_over_federation()
+        self.assertEqual(on_new_event.call_count, 3)
+
+    def _send_event_over_federation(self) -> None:
+        """Send a dummy event over federation and check that the request succeeds."""
+        body = {
+            "origin": self.hs.config.server.server_name,
+            "origin_server_ts": self.clock.time_msec(),
+            "pdus": [
+                {
+                    "sender": self.user_id,
+                    "type": EventTypes.Message,
+                    "state_key": "",
+                    "content": {"body": "hello world", "msgtype": "m.text"},
+                    "room_id": self.room_id,
+                    "depth": 0,
+                    "origin_server_ts": self.clock.time_msec(),
+                    "prev_events": [],
+                    "auth_events": [],
+                    "signatures": {},
+                    "unsigned": {},
+                }
+            ],
+        }
+
+        channel = self.make_request(
+            method="PUT",
+            path="/_matrix/federation/v1/send/1",
+            content=body,
+            federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
+        )
+
+        self.assertEqual(channel.code, 200, channel.result)
+
     def _update_power_levels(self, event_default: int = 0):
         """Updates the room's power levels.