summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-02-04 15:27:01 -0600
committerEric Eastwood <erice@element.io>2022-02-04 15:27:01 -0600
commit47590bb19e87d626c94f9529eda6cda6c2359bec (patch)
tree5577c9470f41b4b85346178af584610ea95b75c6 /tests/rest/client
parentMerge branch 'develop' into madlittlemods/return-historical-events-in-order-f... (diff)
parentClarify that users' media are also preview images (#11862) (diff)
downloadsynapse-47590bb19e87d626c94f9529eda6cda6c2359bec.tar.xz
Merge branch 'develop' into madlittlemods/return-historical-events-in-order-from-backfill
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_profile.py156
-rw-r--r--tests/rest/client/test_register.py43
-rw-r--r--tests/rest/client/test_relations.py117
-rw-r--r--tests/rest/client/test_room_batch.py2
-rw-r--r--tests/rest/client/utils.py31
5 files changed, 304 insertions, 45 deletions
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 2860579c2e..ead883ded8 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -13,8 +13,12 @@
 # limitations under the License.
 
 """Tests REST events for /profile paths."""
+from typing import Any, Dict
+
+from synapse.api.errors import Codes
 from synapse.rest import admin
 from synapse.rest.client import login, profile, room
+from synapse.types import UserID
 
 from tests import unittest
 
@@ -25,6 +29,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         admin.register_servlets_for_client_rest_resource,
         login.register_servlets,
         profile.register_servlets,
+        room.register_servlets,
     ]
 
     def make_homeserver(self, reactor, clock):
@@ -150,6 +155,157 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body.get("avatar_url")
 
+    @unittest.override_config({"max_avatar_size": 50})
+    def test_avatar_size_limit_global(self):
+        """Tests that the maximum size limit for avatars is enforced when updating a
+        global profile.
+        """
+        self._setup_local_files(
+            {
+                "small": {"size": 40},
+                "big": {"size": 60},
+            }
+        )
+
+        channel = self.make_request(
+            "PUT",
+            f"/profile/{self.owner}/avatar_url",
+            content={"avatar_url": "mxc://test/big"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+        )
+
+        channel = self.make_request(
+            "PUT",
+            f"/profile/{self.owner}/avatar_url",
+            content={"avatar_url": "mxc://test/small"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    @unittest.override_config({"max_avatar_size": 50})
+    def test_avatar_size_limit_per_room(self):
+        """Tests that the maximum size limit for avatars is enforced when updating a
+        per-room profile.
+        """
+        self._setup_local_files(
+            {
+                "small": {"size": 40},
+                "big": {"size": 60},
+            }
+        )
+
+        room_id = self.helper.create_room_as(tok=self.owner_tok)
+
+        channel = self.make_request(
+            "PUT",
+            f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+            content={"membership": "join", "avatar_url": "mxc://test/big"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+        )
+
+        channel = self.make_request(
+            "PUT",
+            f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+            content={"membership": "join", "avatar_url": "mxc://test/small"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+    def test_avatar_allowed_mime_type_global(self):
+        """Tests that the MIME type whitelist for avatars is enforced when updating a
+        global profile.
+        """
+        self._setup_local_files(
+            {
+                "good": {"mimetype": "image/png"},
+                "bad": {"mimetype": "application/octet-stream"},
+            }
+        )
+
+        channel = self.make_request(
+            "PUT",
+            f"/profile/{self.owner}/avatar_url",
+            content={"avatar_url": "mxc://test/bad"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+        )
+
+        channel = self.make_request(
+            "PUT",
+            f"/profile/{self.owner}/avatar_url",
+            content={"avatar_url": "mxc://test/good"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+    def test_avatar_allowed_mime_type_per_room(self):
+        """Tests that the MIME type whitelist for avatars is enforced when updating a
+        per-room profile.
+        """
+        self._setup_local_files(
+            {
+                "good": {"mimetype": "image/png"},
+                "bad": {"mimetype": "application/octet-stream"},
+            }
+        )
+
+        room_id = self.helper.create_room_as(tok=self.owner_tok)
+
+        channel = self.make_request(
+            "PUT",
+            f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+            content={"membership": "join", "avatar_url": "mxc://test/bad"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 403, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"], Codes.FORBIDDEN, channel.json_body
+        )
+
+        channel = self.make_request(
+            "PUT",
+            f"/rooms/{room_id}/state/m.room.member/{self.owner}",
+            content={"membership": "join", "avatar_url": "mxc://test/good"},
+            access_token=self.owner_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
+        """Stores metadata about files in the database.
+
+        Args:
+            names_and_props: A dictionary with one entry per file, with the key being the
+                file's name, and the value being a dictionary of properties. Supported
+                properties are "mimetype" (for the file's type) and "size" (for the
+                file's size).
+        """
+        store = self.hs.get_datastore()
+
+        for name, props in names_and_props.items():
+            self.get_success(
+                store.store_local_media(
+                    media_id=name,
+                    media_type=props.get("mimetype", "image/png"),
+                    time_now_ms=self.clock.time_msec(),
+                    upload_name=None,
+                    media_length=props.get("size", 50),
+                    user_id=UserID.from_string("@rin:test"),
+                )
+            )
+
 
 class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 6e7c0f11df..0f1c47dcbb 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -726,6 +726,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
         )
 
+    @override_config(
+        {
+            "inhibit_user_in_use_error": True,
+        }
+    )
+    def test_inhibit_user_in_use_error(self):
+        """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
+        correctly.
+        """
+        username = "arthur"
+
+        # Manually register the user, so we know the test isn't passing because of a lack
+        # of clashing.
+        reg_handler = self.hs.get_registration_handler()
+        self.get_success(reg_handler.register_user(username))
+
+        # Check that /available correctly ignores the username provided despite the
+        # username being already registered.
+        channel = self.make_request("GET", "register/available?username=" + username)
+        self.assertEquals(200, channel.code, channel.result)
+
+        # Test that when starting a UIA registration flow the request doesn't fail because
+        # of a conflicting username
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"username": username, "type": "m.login.password", "password": "foo"},
+        )
+        self.assertEqual(channel.code, 401)
+        self.assertIn("session", channel.json_body)
+
+        # Test that finishing the registration fails because of a conflicting username.
+        session = channel.json_body["session"]
+        channel = self.make_request(
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": LoginType.DUMMY}},
+        )
+        self.assertEqual(channel.code, 400, channel.json_body)
+        self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
+
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):
 
@@ -1113,7 +1154,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
 class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
     servlets = [register.register_servlets]
-    url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+    url = "/_matrix/client/v1/register/m.login.registration_token/validity"
 
     def default_config(self):
         config = super().default_config()
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ff4e81d069..96ae7790bb 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,6 +21,7 @@ from unittest.mock import patch
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
+from synapse.types import JsonDict
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -454,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
     def test_bundled_aggregations(self):
-        """Test that annotations, references, and threads get correctly bundled."""
+        """
+        Test that annotations, references, and threads get correctly bundled.
+
+        Note that this doesn't test against /relations since only thread relations
+        get bundled via that API. See test_aggregation_get_event_for_thread.
+
+        See test_edit for a similar test for edits.
+        """
         # Setup by sending a variety of relations.
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(200, channel.code, channel.json_body)
@@ -482,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         thread_2 = channel.json_body["event_id"]
 
-        def assert_bundle(actual):
+        def assert_bundle(event_json: JsonDict) -> None:
             """Assert the expected values of the bundled aggregations."""
+            relations_dict = event_json["unsigned"].get("m.relations")
 
             # Ensure the fields are as expected.
             self.assertCountEqual(
-                actual.keys(),
+                relations_dict.keys(),
                 (
                     RelationTypes.ANNOTATION,
                     RelationTypes.REFERENCE,
@@ -503,17 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                         {"type": "m.reaction", "key": "b", "count": 1},
                     ]
                 },
-                actual[RelationTypes.ANNOTATION],
+                relations_dict[RelationTypes.ANNOTATION],
             )
 
             self.assertEquals(
                 {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
-                actual[RelationTypes.REFERENCE],
+                relations_dict[RelationTypes.REFERENCE],
             )
 
             self.assertEquals(
                 2,
-                actual[RelationTypes.THREAD].get("count"),
+                relations_dict[RelationTypes.THREAD].get("count"),
+            )
+            self.assertTrue(
+                relations_dict[RelationTypes.THREAD].get("current_user_participated")
             )
             # The latest thread event has some fields that don't matter.
             self.assert_dict(
@@ -530,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
                     "type": "m.room.test",
                     "user_id": self.user_id,
                 },
-                actual[RelationTypes.THREAD].get("latest_event"),
+                relations_dict[RelationTypes.THREAD].get("latest_event"),
             )
 
-        def _find_and_assert_event(events):
-            """
-            Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
-            """
-            for event in events:
-                if event["event_id"] == self.parent_id:
-                    break
-            else:
-                raise AssertionError(f"Event {self.parent_id} not found in chunk")
-            assert_bundle(event["unsigned"].get("m.relations"))
-
         # Request the event directly.
         channel = self.make_request(
             "GET",
@@ -551,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_token=self.user_token,
         )
         self.assertEquals(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body["unsigned"].get("m.relations"))
+        assert_bundle(channel.json_body)
 
         # Request the room messages.
         channel = self.make_request(
@@ -560,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_token=self.user_token,
         )
         self.assertEquals(200, channel.code, channel.json_body)
-        _find_and_assert_event(channel.json_body["chunk"])
+        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
 
         # Request the room context.
         channel = self.make_request(
@@ -569,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             access_token=self.user_token,
         )
         self.assertEquals(200, channel.code, channel.json_body)
-        assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
+        assert_bundle(channel.json_body["event"])
 
         # Request sync.
-        # channel = self.make_request("GET", "/sync", access_token=self.user_token)
-        # self.assertEquals(200, channel.code, channel.json_body)
-        # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
-        # self.assertTrue(room_timeline["limited"])
-        # _find_and_assert_event(room_timeline["events"])
-
-        # Note that /relations is tested separately in test_aggregation_get_event_for_thread
-        # since it needs different data configured.
+        channel = self.make_request("GET", "/sync", access_token=self.user_token)
+        self.assertEquals(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
     def test_aggregation_get_event_for_annotation(self):
         """Test that annotations do not get bundled aggregations included
@@ -774,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         edit_event_id = channel.json_body["event_id"]
 
+        def assert_bundle(event_json: JsonDict) -> None:
+            """Assert the expected values of the bundled aggregations."""
+            relations_dict = event_json["unsigned"].get("m.relations")
+            self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+            m_replace_dict = relations_dict[RelationTypes.REPLACE]
+            for key in ["event_id", "sender", "origin_server_ts"]:
+                self.assertIn(key, m_replace_dict)
+
+            self.assert_dict(
+                {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+            )
+
         channel = self.make_request(
             "GET",
-            "/rooms/%s/event/%s" % (self.room, self.parent_id),
+            f"/rooms/{self.room}/event/{self.parent_id}",
             access_token=self.user_token,
         )
         self.assertEquals(200, channel.code, channel.json_body)
-
         self.assertEquals(channel.json_body["content"], new_body)
+        assert_bundle(channel.json_body)
 
-        relations_dict = channel.json_body["unsigned"].get("m.relations")
-        self.assertIn(RelationTypes.REPLACE, relations_dict)
+        # Request the room messages.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
 
-        m_replace_dict = relations_dict[RelationTypes.REPLACE]
-        for key in ["event_id", "sender", "origin_server_ts"]:
-            self.assertIn(key, m_replace_dict)
+        # Request the room context.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/context/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        assert_bundle(channel.json_body["event"])
 
-        self.assert_dict(
-            {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+        # Request sync, but limit the timeline so it becomes limited (and includes
+        # bundled aggregations).
+        filter = urllib.parse.quote_plus(
+            '{"room": {"timeline": {"limit": 2}}}'.encode()
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
         )
+        self.assertEquals(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        self.assertTrue(room_timeline["limited"])
+        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
     def test_multi_edit(self):
         """Test that multiple edits, including attempts by people who
@@ -1099,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         self.assertEquals(channel.json_body["chunk"], [])
 
+    def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+        """
+        Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+        """
+        for event in events:
+            if event["event_id"] == self.parent_id:
+                return event
+
+        raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
     def _send_relation(
         self,
         relation_type: str,
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 721454c187..e9f8704035 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -89,7 +89,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
         self.clock = clock
         self.storage = hs.get_storage()
 
-        self.virtual_user_id = self.register_appservice_user(
+        self.virtual_user_id, _ = self.register_appservice_user(
             "as_user_potato", self.appservice.token
         )
 
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 8424383580..1c0cb0cf4f 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,6 +31,7 @@ from typing import (
     overload,
 )
 from unittest.mock import patch
+from urllib.parse import urlencode
 
 import attr
 from typing_extensions import Literal
@@ -147,12 +148,20 @@ class RestHelper:
             expect_code=expect_code,
         )
 
-    def join(self, room=None, user=None, expect_code=200, tok=None):
+    def join(
+        self,
+        room: str,
+        user: Optional[str] = None,
+        expect_code: int = 200,
+        tok: Optional[str] = None,
+        appservice_user_id: Optional[str] = None,
+    ) -> None:
         self.change_membership(
             room=room,
             src=user,
             targ=user,
             tok=tok,
+            appservice_user_id=appservice_user_id,
             membership=Membership.JOIN,
             expect_code=expect_code,
         )
@@ -209,11 +218,12 @@ class RestHelper:
     def change_membership(
         self,
         room: str,
-        src: str,
-        targ: str,
+        src: Optional[str],
+        targ: Optional[str],
         membership: str,
         extra_data: Optional[dict] = None,
         tok: Optional[str] = None,
+        appservice_user_id: Optional[str] = None,
         expect_code: int = 200,
         expect_errcode: Optional[str] = None,
     ) -> None:
@@ -227,15 +237,26 @@ class RestHelper:
             membership: The type of membership event
             extra_data: Extra information to include in the content of the event
             tok: The user access token to use
+            appservice_user_id: The `user_id` URL parameter to pass.
+                This allows driving an application service user
+                using an application service access token in `tok`.
             expect_code: The expected HTTP response code
             expect_errcode: The expected Matrix error code
         """
         temp_id = self.auth_user_id
         self.auth_user_id = src
 
-        path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
+        path = f"/_matrix/client/r0/rooms/{room}/state/m.room.member/{targ}"
+        url_params: Dict[str, str] = {}
+
         if tok:
-            path = path + "?access_token=%s" % tok
+            url_params["access_token"] = tok
+
+        if appservice_user_id:
+            url_params["user_id"] = appservice_user_id
+
+        if url_params:
+            path += "?" + urlencode(url_params)
 
         data = {"membership": membership}
         data.update(extra_data or {})