summary refs log tree commit diff
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2021-07-20 12:39:46 +0200
committerGitHub <noreply@github.com>2021-07-20 12:39:46 +0200
commita743bf46949e851c9a10d8e01a138659f3af2484 (patch)
treefc6fe51e777b197c83d3aef49011fe13d35b7130
parentFix exception when failing to get remote room list (#10414) (diff)
downloadsynapse-a743bf46949e851c9a10d8e01a138659f3af2484.tar.xz
Port the ThirdPartyEventRules module interface to the new generic interface (#10386)
Port the third-party event rules interface to the generic module interface introduced in v1.37.0
-rw-r--r--changelog.d/10386.removal1
-rw-r--r--docs/modules.md62
-rw-r--r--docs/sample_config.yaml13
-rw-r--r--docs/upgrade.md13
-rw-r--r--synapse/app/_base.py2
-rw-r--r--synapse/config/third_party_event_rules.py15
-rw-r--r--synapse/events/third_party_rules.py245
-rw-r--r--synapse/handlers/federation.py4
-rw-r--r--synapse/handlers/message.py8
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/module_api/__init__.py6
-rw-r--r--tests/rest/client/test_third_party_rules.py132
12 files changed, 403 insertions, 108 deletions
diff --git a/changelog.d/10386.removal b/changelog.d/10386.removal
new file mode 100644
index 0000000000..800a6143d7
--- /dev/null
+++ b/changelog.d/10386.removal
@@ -0,0 +1 @@
+The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information.
diff --git a/docs/modules.md b/docs/modules.md
index c4cb7018f7..9a430390a4 100644
--- a/docs/modules.md
+++ b/docs/modules.md
@@ -186,7 +186,7 @@ The arguments passed to this callback are:
 ```python
 async def check_media_file_for_spam(
     file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper",
-    file_info: "synapse.rest.media.v1._base.FileInfo"
+    file_info: "synapse.rest.media.v1._base.FileInfo",
 ) -> bool
 ```
 
@@ -223,6 +223,66 @@ Called after successfully registering a user, in case the module needs to perfor
 operations to keep track of them. (e.g. add them to a database table). The user is
 represented by their Matrix user ID.
 
+#### Third party rules callbacks
+
+Third party rules callbacks allow module developers to add extra checks to verify the
+validity of incoming events. Third party event rules callbacks can be registered using
+the module API's `register_third_party_rules_callbacks` method.
+
+The available third party rules callbacks are:
+
+```python
+async def check_event_allowed(
+    event: "synapse.events.EventBase",
+    state_events: "synapse.types.StateMap",
+) -> Tuple[bool, Optional[dict]]
+```
+
+**<span style="color:red">
+This callback is very experimental and can and will break without notice. Module developers
+are encouraged to implement `check_event_for_spam` from the spam checker category instead.
+</span>**
+
+Called when processing any incoming event, with the event and a `StateMap`
+representing the current state of the room the event is being sent into. A `StateMap` is
+a dictionary that maps tuples containing an event type and a state key to the
+corresponding state event. For example retrieving the room's `m.room.create` event from
+the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
+The module must return a boolean indicating whether the event can be allowed.
+
+Note that this callback function processes incoming events coming via federation
+traffic (on top of client traffic). This means denying an event might cause the local
+copy of the room's history to diverge from that of remote servers. This may cause
+federation issues in the room. It is strongly recommended to only deny events using this
+callback function if the sender is a local user, or in a private federation in which all
+servers are using the same module, with the same configuration.
+
+If the boolean returned by the module is `True`, it may also tell Synapse to replace the
+event with new data by returning the new event's data as a dictionary. In order to do
+that, it is recommended the module calls `event.get_dict()` to get the current event as a
+dictionary, and modify the returned dictionary accordingly.
+
+Note that replacing the event only works for events sent by local users, not for events
+received over federation.
+
+```python
+async def on_create_room(
+    requester: "synapse.types.Requester",
+    request_content: dict,
+    is_requester_admin: bool,
+) -> None
+```
+
+Called when processing a room creation request, with the `Requester` object for the user
+performing the request, a dictionary representing the room creation request's JSON body
+(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom)
+for a list of possible parameters), and a boolean indicating whether the user performing
+the request is a server admin.
+
+Modules can modify the `request_content` (by e.g. adding events to its `initial_state`),
+or deny the room's creation by raising a `module_api.errors.SynapseError`.
+
+
 ### Porting an existing module that uses the old interface
 
 In order to port a module that uses Synapse's old module interface, its author needs to:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index f4845a5841..853c2f6899 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2654,19 +2654,6 @@ stats:
 #    action: allow
 
 
-# Server admins can define a Python module that implements extra rules for
-# allowing or denying incoming events. In order to work, this module needs to
-# override the methods defined in synapse/events/third_party_rules.py.
-#
-# This feature is designed to be used in closed federations only, where each
-# participating server enforces the same rules.
-#
-#third_party_event_rules:
-#  module: "my_custom_project.SuperRulesSet"
-#  config:
-#    example_option: 'things'
-
-
 ## Opentracing ##
 
 # These settings enable opentracing, which implements distributed tracing.
diff --git a/docs/upgrade.md b/docs/upgrade.md
index db0450f563..c8f4a2c171 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -86,6 +86,19 @@ process, for example:
     ```
 
 
+# Upgrading to v1.39.0
+
+## Deprecation of the current third-party rules module interface
+
+The current third-party rules module interface is deprecated in favour of the new generic
+modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer
+to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface)
+to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules)
+to update their configuration once the modules they are using have been updated.
+
+We plan to remove support for the current third-party rules interface in September 2021.
+
+
 # Upgrading to v1.38.0
 
 ## Re-indexing of `events` table on Postgres databases
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index b30571fe49..50a02f51f5 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.events.spamcheck import load_legacy_spam_checkers
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -368,6 +369,7 @@ async def start(hs: "HomeServer"):
         module(config=config, api=module_api)
 
     load_legacy_spam_checkers(hs)
+    load_legacy_third_party_event_rules(hs)
 
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index f502ff539e..a3fae02420 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config):
             self.third_party_event_rules = load_module(
                 provider, ("third_party_event_rules",)
             )
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Server admins can define a Python module that implements extra rules for
-        # allowing or denying incoming events. In order to work, this module needs to
-        # override the methods defined in synapse/events/third_party_rules.py.
-        #
-        # This feature is designed to be used in closed federations only, where each
-        # participating server enforces the same rules.
-        #
-        #third_party_event_rules:
-        #  module: "my_custom_project.SuperRulesSet"
-        #  config:
-        #    example_option: 'things'
-        """
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index f7944fd834..7a6eb3e516 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -11,16 +11,124 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
-from typing import TYPE_CHECKING, Union
-
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Requester, StateMap
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
+
+CHECK_EVENT_ALLOWED_CALLBACK = Callable[
+    [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
+]
+ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
+CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
+    [str, str, StateMap[EventBase]], Awaitable[bool]
+]
+CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
+    [str, StateMap[EventBase], str], Awaitable[bool]
+]
+
+
+def load_legacy_third_party_event_rules(hs: "HomeServer"):
+    """Wrapper that loads a third party event rules module configured using the old
+    configuration, and registers the hooks they implement.
+    """
+    if hs.config.third_party_event_rules is None:
+        return
+
+    module, config = hs.config.third_party_event_rules
+
+    api = hs.get_module_api()
+    third_party_rules = module(config=config, module_api=api)
+
+    # The known hooks. If a module implements a method which name appears in this set,
+    # we'll want to register it.
+    third_party_event_rules_methods = {
+        "check_event_allowed",
+        "on_create_room",
+        "check_threepid_can_be_invited",
+        "check_visibility_can_be_modified",
+    }
+
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We return a separate wrapper for these methods because, in order to wrap them
+        # correctly, we need to await its result. Therefore it doesn't make a lot of
+        # sense to make it go through the run() wrapper.
+        if f.__name__ == "check_event_allowed":
+
+            # We need to wrap check_event_allowed because its old form would return either
+            # a boolean or a dict, but now we want to return the dict separately from the
+            # boolean.
+            async def wrap_check_event_allowed(
+                event: EventBase,
+                state_events: StateMap[EventBase],
+            ) -> Tuple[bool, Optional[dict]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                res = await f(event, state_events)
+                if isinstance(res, dict):
+                    return True, res
+                else:
+                    return res, None
+
+            return wrap_check_event_allowed
+
+        if f.__name__ == "on_create_room":
+
+            # We need to wrap on_create_room because its old form would return a boolean
+            # if the room creation is denied, but now we just want it to raise an
+            # exception.
+            async def wrap_on_create_room(
+                requester: Requester, config: dict, is_requester_admin: bool
+            ) -> None:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                res = await f(requester, config, is_requester_admin)
+                if res is False:
+                    raise SynapseError(
+                        403,
+                        "Room creation forbidden with these parameters",
+                    )
+
+            return wrap_on_create_room
+
+        def run(*args, **kwargs):
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
+
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(third_party_rules, hook, None))
+        for hook in third_party_event_rules_methods
+    }
+
+    api.register_third_party_rules_callbacks(**hooks)
+
 
 class ThirdPartyEventRules:
     """Allows server admins to provide a Python module implementing an extra
@@ -35,36 +143,65 @@ class ThirdPartyEventRules:
 
         self.store = hs.get_datastore()
 
-        module = None
-        config = None
-        if hs.config.third_party_event_rules:
-            module, config = hs.config.third_party_event_rules
+        self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
+        self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
+        self._check_threepid_can_be_invited_callbacks: List[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = []
+        self._check_visibility_can_be_modified_callbacks: List[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = []
+
+    def register_third_party_rules_callbacks(
+        self,
+        check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+        on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
+        check_threepid_can_be_invited: Optional[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = None,
+        check_visibility_can_be_modified: Optional[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = None,
+    ):
+        """Register callbacks from modules for each hook."""
+        if check_event_allowed is not None:
+            self._check_event_allowed_callbacks.append(check_event_allowed)
+
+        if on_create_room is not None:
+            self._on_create_room_callbacks.append(on_create_room)
+
+        if check_threepid_can_be_invited is not None:
+            self._check_threepid_can_be_invited_callbacks.append(
+                check_threepid_can_be_invited,
+            )
 
-        if module is not None:
-            self.third_party_rules = module(
-                config=config,
-                module_api=hs.get_module_api(),
+        if check_visibility_can_be_modified is not None:
+            self._check_visibility_can_be_modified_callbacks.append(
+                check_visibility_can_be_modified,
             )
 
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
-    ) -> Union[bool, dict]:
+    ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
         The module can return:
             * True: the event is allowed.
             * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
-            * a dict: replacement event data.
+
+        If the event is allowed, the module can also return a dictionary to use as a
+        replacement for the event.
 
         Args:
             event: The event to be checked.
             context: The context of the event.
 
         Returns:
-            The result from the ThirdPartyRules module, as above
+            The result from the ThirdPartyRules module, as above.
         """
-        if self.third_party_rules is None:
-            return True
+        # Bail out early without hitting the store if we don't have any callbacks to run.
+        if len(self._check_event_allowed_callbacks) == 0:
+            return True, None
 
         prev_state_ids = await context.get_prev_state_ids()
 
@@ -77,29 +214,46 @@ class ThirdPartyEventRules:
         # the hashes and signatures.
         event.freeze()
 
-        return await self.third_party_rules.check_event_allowed(event, state_events)
+        for callback in self._check_event_allowed_callbacks:
+            try:
+                res, replacement_data = await callback(event, state_events)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            # Return if the event shouldn't be allowed or if the module came up with a
+            # replacement dict for the event.
+            if res is False:
+                return res, None
+            elif isinstance(replacement_data, dict):
+                return True, replacement_data
+
+        return True, None
 
     async def on_create_room(
         self, requester: Requester, config: dict, is_requester_admin: bool
-    ) -> bool:
-        """Intercept requests to create room to allow, deny or update the
-        request config.
+    ) -> None:
+        """Intercept requests to create room to maybe deny it (via an exception) or
+        update the request config.
 
         Args:
             requester
             config: The creation config from the client.
             is_requester_admin: If the requester is an admin
-
-        Returns:
-            Whether room creation is allowed or denied.
         """
-
-        if self.third_party_rules is None:
-            return True
-
-        return await self.third_party_rules.on_create_room(
-            requester, config, is_requester_admin
-        )
+        for callback in self._on_create_room_callbacks:
+            try:
+                await callback(requester, config, is_requester_admin)
+            except Exception as e:
+                # Don't silence the errors raised by this callback since we expect it to
+                # raise an exception to deny the creation of the room; instead make sure
+                # it's a SynapseError we can send to clients.
+                if not isinstance(e, SynapseError):
+                    e = SynapseError(
+                        403, "Room creation forbidden with these parameters"
+                    )
+
+                raise e
 
     async def check_threepid_can_be_invited(
         self, medium: str, address: str, room_id: str
@@ -114,15 +268,20 @@ class ThirdPartyEventRules:
         Returns:
             True if the 3PID can be invited, False if not.
         """
-
-        if self.third_party_rules is None:
+        # Bail out early without hitting the store if we don't have any callbacks to run.
+        if len(self._check_threepid_can_be_invited_callbacks) == 0:
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        return await self.third_party_rules.check_threepid_can_be_invited(
-            medium, address, state_events
-        )
+        for callback in self._check_threepid_can_be_invited_callbacks:
+            try:
+                if await callback(medium, address, state_events) is False:
+                    return False
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+        return True
 
     async def check_visibility_can_be_modified(
         self, room_id: str, new_visibility: str
@@ -137,18 +296,20 @@ class ThirdPartyEventRules:
         Returns:
             True if the room's visibility can be modified, False if not.
         """
-        if self.third_party_rules is None:
-            return True
-
-        check_func = getattr(
-            self.third_party_rules, "check_visibility_can_be_modified", None
-        )
-        if not check_func or not callable(check_func):
+        # Bail out early without hitting the store if we don't have any callback
+        if len(self._check_visibility_can_be_modified_callbacks) == 0:
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        return await check_func(room_id, state_events, new_visibility)
+        for callback in self._check_visibility_can_be_modified_callbacks:
+            try:
+                if await callback(room_id, state_events, new_visibility) is False:
+                    return False
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+        return True
 
     async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
         """Given a room ID, return the state events of that room.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index cf389be3e4..5728719909 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler):
             builder=builder
         )
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
+        event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler):
         # for knock events, we run the third-party event rules. It's not entirely clear
         # why we don't do this for other sorts of membership events.
         if event.membership == Membership.KNOCK:
-            event_allowed = await self.third_party_event_rules.check_event_allowed(
+            event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
                 event, context
             )
             if not event_allowed:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c7fe4ff89e..8a0024ce84 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -949,10 +949,10 @@ class EventCreationHandler:
         if requester:
             context.app_service = requester.app_service
 
-        third_party_result = await self.third_party_event_rules.check_event_allowed(
+        res, new_content = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
-        if not third_party_result:
+        if res is False:
             logger.info(
                 "Event %s forbidden by third-party rules",
                 event,
@@ -960,11 +960,11 @@ class EventCreationHandler:
             raise SynapseError(
                 403, "This event is not allowed in this context", Codes.FORBIDDEN
             )
-        elif isinstance(third_party_result, dict):
+        elif new_content is not None:
             # the third-party rules want to replace the event. We'll need to build a new
             # event.
             event, context = await self._rebuild_event_after_third_party_rules(
-                third_party_result, event
+                new_content, event
             )
 
         self.validator.validate_new(event, self.config)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 64656fda22..370561e549 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
         else:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
 
-        # Check whether the third party rules allows/changes the room create
-        # request.
-        event_allowed = await self.third_party_event_rules.on_create_room(
+        # Let the third party rules modify the room creation config if needed, or abort
+        # the room creation entirely with an exception.
+        await self.third_party_event_rules.on_create_room(
             requester, config, is_requester_admin=is_requester_admin
         )
-        if not event_allowed:
-            raise SynapseError(
-                403, "You are not permitted to create rooms", Codes.FORBIDDEN
-            )
 
         if not is_requester_admin and not await self.spam_checker.user_may_create_room(
             user_id
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 5df9349134..1259fc2d90 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -110,6 +110,7 @@ class ModuleApi:
 
         self._spam_checker = hs.get_spam_checker()
         self._account_validity_handler = hs.get_account_validity_handler()
+        self._third_party_event_rules = hs.get_third_party_event_rules()
 
     #################################################################################
     # The following methods should only be called during the module's initialisation.
@@ -124,6 +125,11 @@ class ModuleApi:
         """Registers callbacks for account validity capabilities."""
         return self._account_validity_handler.register_account_validity_callbacks
 
+    @property
+    def register_third_party_rules_callbacks(self):
+        """Registers callbacks for third party event rules capabilities."""
+        return self._third_party_event_rules.register_third_party_rules_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index c5e1c5458b..28dd47a28b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -16,17 +16,19 @@ from typing import Dict
 from unittest.mock import Mock
 
 from synapse.events import EventBase
+from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 from synapse.types import Requester, StateMap
+from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
 
 thread_local = threading.local()
 
 
-class ThirdPartyRulesTestModule:
+class LegacyThirdPartyRulesTestModule:
     def __init__(self, config: Dict, module_api: ModuleApi):
         # keep a record of the "current" rules module, so that the test can patch
         # it if desired.
@@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
         return config
 
 
-def current_rules_module() -> ThirdPartyRulesTestModule:
-    return thread_local.rules_module
+class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        super().__init__(config, module_api)
+
+    def on_create_room(
+        self, requester: Requester, config: dict, is_requester_admin: bool
+    ):
+        return False
+
+
+class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        super().__init__(config, module_api)
+
+    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+        d = event.get_dict()
+        content = unfreeze(event.content)
+        content["foo"] = "bar"
+        d["content"] = content
+        return d
 
 
 class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def default_config(self):
-        config = super().default_config()
-        config["third_party_event_rules"] = {
-            "module": __name__ + ".ThirdPartyRulesTestModule",
-            "config": {},
-        }
-        return config
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        load_legacy_third_party_event_rules(hs)
+
+        return hs
 
     def prepare(self, reactor, clock, homeserver):
         # Create a user and room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
         self.tok = self.login("kermit", "monkey")
 
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        # Some tests might prevent room creation on purpose.
+        try:
+            self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        except Exception:
+            pass
 
     def test_third_party_rules(self):
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
@@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         # patch the rules module with a Mock which will return False for some event
         # types
         async def check(ev, state):
-            return ev.type != "foo.bar.forbidden"
+            return ev.type != "foo.bar.forbidden", None
 
         callback = Mock(spec=[], side_effect=check)
-        current_rules_module().check_event_allowed = callback
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
+            callback
+        ]
 
         channel = self.make_request(
             "PUT",
@@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         # first patch the event checker so that it will try to modify the event
         async def check(ev: EventBase, state):
             ev.content = {"x": "y"}
-            return True
+            return True, None
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # now send the event
         channel = self.make_request(
@@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             {"x": "x"},
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"500", channel.result)
+        # check_event_allowed has some error handling, so it shouldn't 500 just because a
+        # module did something bad.
+        self.assertEqual(channel.code, 200, channel.result)
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["x"], "x")
 
     def test_modify_event(self):
         """The module can return a modified version of the event"""
@@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         async def check(ev: EventBase, state):
             d = ev.get_dict()
             d["content"] = {"x": "y"}
-            return d
+            return True, d
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # now send the event
         channel = self.make_request(
@@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
                 "msgtype": "m.text",
                 "body": d["content"]["body"].upper(),
             }
-            return d
+            return True, d
 
-        current_rules_module().check_event_allowed = check
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
 
         # Send an event, then edit it.
         channel = self.make_request(
@@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(ev["content"]["body"], "EDITED BODY")
 
     def test_send_event(self):
-        """Tests that the module can send an event into a room via the module api"""
+        """Tests that a module can send an event into a room via the module api"""
         content = {
             "msgtype": "m.text",
             "body": "Hello!",
@@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             "sender": self.user_id,
         }
         event: EventBase = self.get_success(
-            current_rules_module().module_api.create_and_send_event_into_room(
-                event_dict
-            )
+            self.hs.get_module_api().create_and_send_event_into_room(event_dict)
         )
 
         self.assertEquals(event.sender, self.user_id)
         self.assertEquals(event.room_id, self.room_id)
         self.assertEquals(event.type, "m.room.message")
         self.assertEquals(event.content, content)
+
+    @unittest.override_config(
+        {
+            "third_party_event_rules": {
+                "module": __name__ + ".LegacyChangeEvents",
+                "config": {},
+            }
+        }
+    )
+    def test_legacy_check_event_allowed(self):
+        """Tests that the wrapper for legacy check_event_allowed callbacks works
+        correctly.
+        """
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
+            {
+                "msgtype": "m.text",
+                "body": "Original body",
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+
+        self.assertIn("foo", channel.json_body["content"].keys())
+        self.assertEqual(channel.json_body["content"]["foo"], "bar")
+
+    @unittest.override_config(
+        {
+            "third_party_event_rules": {
+                "module": __name__ + ".LegacyDenyNewRooms",
+                "config": {},
+            }
+        }
+    )
+    def test_legacy_on_create_room(self):
+        """Tests that the wrapper for legacy on_create_room callbacks works
+        correctly.
+        """
+        self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)