summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_third_party_rules.py56
1 files changed, 50 insertions, 6 deletions
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 38ac9be113..531f09c48b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -12,25 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import threading
-from typing import Dict
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from unittest.mock import Mock
 
 from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
-from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, room
-from synapse.types import Requester, StateMap
+from synapse.types import JsonDict, Requester, StateMap
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
 
+if TYPE_CHECKING:
+    from synapse.module_api import ModuleApi
+
 thread_local = threading.local()
 
 
 class LegacyThirdPartyRulesTestModule:
-    def __init__(self, config: Dict, module_api: ModuleApi):
+    def __init__(self, config: Dict, module_api: "ModuleApi"):
         # keep a record of the "current" rules module, so that the test can patch
         # it if desired.
         thread_local.rules_module = self
@@ -50,7 +53,7 @@ class LegacyThirdPartyRulesTestModule:
 
 
 class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
-    def __init__(self, config: Dict, module_api: ModuleApi):
+    def __init__(self, config: Dict, module_api: "ModuleApi"):
         super().__init__(config, module_api)
 
     def on_create_room(
@@ -60,7 +63,7 @@ class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
 
 
 class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
-    def __init__(self, config: Dict, module_api: ModuleApi):
+    def __init__(self, config: Dict, module_api: "ModuleApi"):
         super().__init__(config, module_api)
 
     async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
@@ -136,6 +139,47 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
+    def test_third_party_rules_workaround_synapse_errors_pass_through(self):
+        """
+        Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
+        is functional: that SynapseErrors are passed through from check_event_allowed
+        and bubble up to the web resource.
+
+        NEW MODULES SHOULD NOT MAKE USE OF THIS WORKAROUND!
+        This is a temporary workaround!
+        """
+
+        class NastyHackException(SynapseError):
+            def error_dict(self):
+                """
+                This overrides SynapseError's `error_dict` to nastily inject
+                JSON into the error response.
+                """
+                result = super().error_dict()
+                result["nasty"] = "very"
+                return result
+
+        # add a callback that will raise our hacky exception
+        async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
+            raise NastyHackException(429, "message")
+
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+
+        # Make a request
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
+            {},
+            access_token=self.tok,
+        )
+        # Check the error code
+        self.assertEquals(channel.result["code"], b"429", channel.result)
+        # Check the JSON body has had the `nasty` key injected
+        self.assertEqual(
+            channel.json_body,
+            {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
+        )
+
     def test_cannot_modify_event(self):
         """cannot accidentally modify an event before it is persisted"""