summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-11-17 14:19:27 +0000
committerDavid Robertson <davidr@element.io>2021-11-17 14:19:27 +0000
commit077b74929f8f412395d1156e1b97eb16701059fa (patch)
tree25ecb8e93ec3ba275596bfb31efa45018645d47b /tests/rest/client
parentCorrect target of link to the modules page from the Password Auth Providers p... (diff)
parent1.47.0 (diff)
downloadsynapse-077b74929f8f412395d1156e1b97eb16701059fa.tar.xz
Merge remote-tracking branch 'origin/release-v1.47'
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_consent.py1
-rw-r--r--tests/rest/client/test_register.py1
-rw-r--r--tests/rest/client/test_sync.py30
-rw-r--r--tests/rest/client/test_third_party_rules.py109
-rw-r--r--tests/rest/client/utils.py29
5 files changed, 150 insertions, 20 deletions
diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py
index 84d092ca82..fcdc565814 100644
--- a/tests/rest/client/test_consent.py
+++ b/tests/rest/client/test_consent.py
@@ -35,7 +35,6 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
 
         config = self.default_config()
-        config["public_baseurl"] = "aaaa"
         config["form_secret"] = "123abc"
 
         # Make some temporary templates...
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 66dcfc9f88..6e7c0f11df 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -891,7 +891,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             "smtp_pass": None,
             "notif_from": "test@example.com",
         }
-        config["public_baseurl"] = "aaa"
 
         self.hs = self.setup_test_homeserver(config=config)
 
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 95be369d4b..c427686376 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 import json
 
+from parameterized import parameterized
+
 import synapse.rest.admin
 from synapse.api.constants import (
     EventContentFields,
@@ -417,7 +419,30 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Test that the first user can't see the other user's hidden read receipt
         self.assertEqual(self._get_read_receipt(), None)
 
-    def test_read_receipt_with_empty_body(self):
+    @parameterized.expand(
+        [
+            # Old Element version, expected to send an empty body
+            (
+                "agent1",
+                "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)",
+                200,
+            ),
+            # Old SchildiChat version, expected to send an empty body
+            ("agent2", "SchildiChat/1.2.1 (Android 10)", 200),
+            # Expected 400: Denies empty body starting at version 1.3+
+            ("agent3", "Element/1.3.6 (Android 10)", 400),
+            ("agent4", "SchildiChat/1.3.6 (Android 11)", 400),
+            # Contains "Riot": Receipts with empty bodies expected
+            ("agent5", "Element (Riot.im) (Android 9)", 200),
+            # Expected 400: Does not contain "Android"
+            ("agent6", "Element/1.2.1", 400),
+            # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage
+            ("agent7", "Element dbg/1.1.8-dev (Android)", 400),
+        ]
+    )
+    def test_read_receipt_with_empty_body(
+        self, name, user_agent: str, expected_status_code: int
+    ):
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
@@ -426,8 +451,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]),
             access_token=self.tok2,
+            custom_headers=[("User-Agent", user_agent)],
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, expected_status_code)
 
     def _get_read_receipt(self):
         """Syncs and returns the read receipt."""
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 531f09c48b..4e71b6ec12 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -15,7 +15,7 @@ import threading
 from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from unittest.mock import Mock
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -25,6 +25,7 @@ from synapse.types import JsonDict, Requester, StateMap
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 if TYPE_CHECKING:
     from synapse.module_api import ModuleApi
@@ -74,7 +75,7 @@ class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
         return d
 
 
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -86,11 +87,29 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
 
         load_legacy_third_party_event_rules(hs)
 
+        # We're not going to be properly signing events as our remote homeserver is fake,
+        # therefore disable event signature checks.
+        # Note that these checks are not relevant to this test case.
+
+        # Have this homeserver auto-approve all event signature checking.
+        async def approve_all_signature_checking(_, pdu):
+            return pdu
+
+        hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
+
+        # Have this homeserver skip event auth checks. This is necessary due to
+        # event auth checks ensuring that events were signed by the sender's homeserver.
+        async def _check_event_auth(origin, event, context, *args, **kwargs):
+            return context
+
+        hs.get_federation_event_handler()._check_event_auth = _check_event_auth
+
         return hs
 
     def prepare(self, reactor, clock, homeserver):
-        # Create a user and room to play with during the tests
+        # Create some users and a room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
+        self.invitee = self.register_user("invitee", "hackme")
         self.tok = self.login("kermit", "monkey")
 
         # Some tests might prevent room creation on purpose.
@@ -197,19 +216,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             {"x": "x"},
             access_token=self.tok,
         )
-        # check_event_allowed has some error handling, so it shouldn't 500 just because a
-        # module did something bad.
-        self.assertEqual(channel.code, 200, channel.result)
-        event_id = channel.json_body["event_id"]
-
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
-            access_token=self.tok,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-        ev = channel.json_body
-        self.assertEqual(ev["content"]["x"], "x")
+        # Because check_event_allowed raises an exception, it leads to a
+        # 500 Internal Server Error
+        self.assertEqual(channel.code, 500, channel.result)
 
     def test_modify_event(self):
         """The module can return a modified version of the event"""
@@ -424,6 +433,74 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
             self.assertEqual(channel.code, 200)
             self.assertEqual(channel.json_body["i"], i)
 
+    def test_on_new_event(self):
+        """Test that the on_new_event callback is called on new events"""
+        on_new_event = Mock(make_awaitable(None))
+        self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
+            on_new_event
+        )
+
+        # Send a message event to the room and check that the callback is called.
+        self.helper.send(room_id=self.room_id, tok=self.tok)
+        self.assertEqual(on_new_event.call_count, 1)
+
+        # Check that the callback is also called on membership updates.
+        self.helper.invite(
+            room=self.room_id,
+            src=self.user_id,
+            targ=self.invitee,
+            tok=self.tok,
+        )
+
+        self.assertEqual(on_new_event.call_count, 2)
+
+        args, _ = on_new_event.call_args
+
+        self.assertEqual(args[0].membership, Membership.INVITE)
+        self.assertEqual(args[0].state_key, self.invitee)
+
+        # Check that the invitee's membership is correct in the state that's passed down
+        # to the callback.
+        self.assertEqual(
+            args[1][(EventTypes.Member, self.invitee)].membership,
+            Membership.INVITE,
+        )
+
+        # Send an event over federation and check that the callback is also called.
+        self._send_event_over_federation()
+        self.assertEqual(on_new_event.call_count, 3)
+
+    def _send_event_over_federation(self) -> None:
+        """Send a dummy event over federation and check that the request succeeds."""
+        body = {
+            "origin": self.hs.config.server.server_name,
+            "origin_server_ts": self.clock.time_msec(),
+            "pdus": [
+                {
+                    "sender": self.user_id,
+                    "type": EventTypes.Message,
+                    "state_key": "",
+                    "content": {"body": "hello world", "msgtype": "m.text"},
+                    "room_id": self.room_id,
+                    "depth": 0,
+                    "origin_server_ts": self.clock.time_msec(),
+                    "prev_events": [],
+                    "auth_events": [],
+                    "signatures": {},
+                    "unsigned": {},
+                }
+            ],
+        }
+
+        channel = self.make_request(
+            method="PUT",
+            path="/_matrix/federation/v1/send/1",
+            content=body,
+            federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
+        )
+
+        self.assertEqual(channel.code, 200, channel.result)
+
     def _update_power_levels(self, event_default: int = 0):
         """Updates the room's power levels.
 
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 71fa87ce92..ec0979850b 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -120,6 +120,35 @@ class RestHelper:
             expect_code=expect_code,
         )
 
+    def knock(self, room=None, user=None, reason=None, expect_code=200, tok=None):
+        temp_id = self.auth_user_id
+        self.auth_user_id = user
+        path = "/knock/%s" % room
+        if tok:
+            path = path + "?access_token=%s" % tok
+
+        data = {}
+        if reason:
+            data["reason"] = reason
+
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            path,
+            json.dumps(data).encode("utf8"),
+        )
+
+        assert (
+            int(channel.result["code"]) == expect_code
+        ), "Expected: %d, got: %d, resp: %r" % (
+            expect_code,
+            int(channel.result["code"]),
+            channel.result["body"],
+        )
+
+        self.auth_user_id = temp_id
+
     def leave(self, room=None, user=None, expect_code=200, tok=None):
         self.change_membership(
             room=room,