diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index a92713d220..708402b792 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -25,16 +25,20 @@ import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional
+from canonicaljson import encode_canonical_json
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -480,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The client requested ?propagate=true, so it should have happened.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field_invalid_field_name(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_custom_field_rejects_bad_username(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "test"})
+
+ # Overwriting the field should work.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "new_Value"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
+
+ # Deleting the field should work.
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_non_string(self) -> None:
+ """Non-string fields are supported for custom fields."""
+ fields = {
+ "bool_field": True,
+ "array_field": ["test"],
+ "object_field": {"test": "test"},
+ "numeric_field": 1,
+ "null_field": None,
+ }
+
+ for key, value in fields.items():
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: value},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
+
+ # Check getting individual fields works.
+ for key, value in fields.items():
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {key: value})
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_noauth(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_size(self) -> None:
+ """
+ Attempts to set a custom field name that is too long should get a 400 error.
+ """
+ # Key is missing.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
+ content={"": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Single key is too large.
+ key = "c" * 500
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ # Key doesn't match body.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"diff_key": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_profile_too_long(self) -> None:
+ """
+ Attempts to set a custom field that would push the overall profile too large.
+ """
+ # Get right to the boundary:
+ # len("displayname") + len("owner") + 5 = 21 for the displayname
+ # 1 + 65498 + 5 for key "a" = 65504
+ # 2 braces, 1 comma
+ # 3 + 21 + 65498 = 65522 < 65536.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "a" * 65498},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Get the entire profile.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ canonical_json = encode_canonical_json(channel.json_body)
+ # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
+ # Be one below that so we can prove we're at the boundary.
+ self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
+
+ # Postgres stores JSONB with whitespace, while SQLite doesn't.
+ if USE_POSTGRES_FOR_TESTS:
+ ADDITIONAL_CHARS = 0
+ else:
+ ADDITIONAL_CHARS = 1
+
+ # The next one should fail, note the value has a (JSON) length of 2.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "1" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Setting an avatar or (longer) display name should not work.
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/displayname",
+ content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://foo/bar"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Removing a single byte should work.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Finally, setting a field that already exists to a value that is <= in length should work.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: ""},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_displayname(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ displayname = self._get_displayname()
+ self.assertEqual(displayname, "test")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_avatar_url(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/good"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ avatar_url = self._get_avatar_url()
+ self.assertEqual(avatar_url, "mxc://test/good")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_other(self) -> None:
+ """Setting someone else's profile field should fail"""
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
|