diff --git a/changelog.d/17488.feature b/changelog.d/17488.feature
new file mode 100644
index 0000000000..15cccf3ac2
--- /dev/null
+++ b/changelog.d/17488.feature
@@ -0,0 +1 @@
+Implement [MSC4133](https://github.com/matrix-org/matrix-spec-proposals/pull/4133) for custom profile fields.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 21989b6e0e..5dd6e84289 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -132,6 +132,10 @@ class Codes(str, Enum):
# connection.
UNKNOWN_POS = "M_UNKNOWN_POS"
+ # Part of MSC4133
+ PROFILE_TOO_LARGE = "M_PROFILE_TOO_LARGE"
+ KEY_TOO_LARGE = "M_KEY_TOO_LARGE"
+
class CodeMessageException(RuntimeError):
"""An exception with integer code, a message string attributes and optional headers.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 90d19849ff..94a25c7ee8 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -436,6 +436,9 @@ class ExperimentalConfig(Config):
("experimental", "msc4108_delegation_endpoint"),
)
+ # MSC4133: Custom profile fields
+ self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False)
+
# MSC4210: Remove legacy mentions
self.msc4210_enabled: bool = experimental.get("msc4210_enabled", False)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 22eedcb54f..cdc388b4ab 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -32,7 +32,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -43,6 +43,8 @@ logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
+# Field name length is specced at 255 bytes.
+MAX_CUSTOM_FIELD_LEN = 255
class ProfileHandler:
@@ -90,7 +92,15 @@ class ProfileHandler:
if self.hs.is_mine(target_user):
profileinfo = await self.store.get_profileinfo(target_user)
- if profileinfo.display_name is None and profileinfo.avatar_url is None:
+ extra_fields = {}
+ if self.hs.config.experimental.msc4133_enabled:
+ extra_fields = await self.store.get_profile_fields(target_user)
+
+ if (
+ profileinfo.display_name is None
+ and profileinfo.avatar_url is None
+ and not extra_fields
+ ):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
# Do not include display name or avatar if unset.
@@ -99,6 +109,9 @@ class ProfileHandler:
ret[ProfileFields.DISPLAYNAME] = profileinfo.display_name
if profileinfo.avatar_url is not None:
ret[ProfileFields.AVATAR_URL] = profileinfo.avatar_url
+ if extra_fields:
+ ret.update(extra_fields)
+
return ret
else:
try:
@@ -403,6 +416,110 @@ class ProfileHandler:
return True
+ async def get_profile_field(
+ self, target_user: UserID, field_name: str
+ ) -> JsonValue:
+ """
+ Fetch a user's profile from the database for local users and over federation
+ for remote users.
+
+ Args:
+ target_user: The user ID to fetch the profile for.
+ field_name: The field to fetch the profile for.
+
+ Returns:
+ The value for the profile field or None if the field does not exist.
+ """
+ if self.hs.is_mine(target_user):
+ try:
+ field_value = await self.store.get_profile_field(
+ target_user, field_name
+ )
+ except StoreError as e:
+ if e.code == 404:
+ raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
+ raise
+
+ return field_value
+ else:
+ try:
+ result = await self.federation.make_query(
+ destination=target_user.domain,
+ query_type="profile",
+ args={"user_id": target_user.to_string(), "field": field_name},
+ ignore_backoff=True,
+ )
+ except RequestSendFailed as e:
+ raise SynapseError(502, "Failed to fetch profile") from e
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+
+ return result.get(field_name)
+
+ async def set_profile_field(
+ self,
+ target_user: UserID,
+ requester: Requester,
+ field_name: str,
+ new_value: JsonValue,
+ by_admin: bool = False,
+ deactivation: bool = False,
+ ) -> None:
+ """Set a new profile field for a user.
+
+ Args:
+ target_user: the user whose profile is to be changed.
+ requester: The user attempting to make this change.
+ field_name: The name of the profile field to update.
+ new_value: The new field value for this user.
+ by_admin: Whether this change was made by an administrator.
+ deactivation: Whether this change was made while deactivating the user.
+ """
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ if not by_admin and target_user != requester.user:
+ raise AuthError(403, "Cannot set another user's profile")
+
+ await self.store.set_profile_field(target_user, field_name, new_value)
+
+ # Custom fields do not propagate into the user directory *or* rooms.
+ profile = await self.store.get_profileinfo(target_user)
+ await self._third_party_rules.on_profile_update(
+ target_user.to_string(), profile, by_admin, deactivation
+ )
+
+ async def delete_profile_field(
+ self,
+ target_user: UserID,
+ requester: Requester,
+ field_name: str,
+ by_admin: bool = False,
+ deactivation: bool = False,
+ ) -> None:
+ """Delete a field from a user's profile.
+
+ Args:
+ target_user: the user whose profile is to be changed.
+ requester: The user attempting to make this change.
+ field_name: The name of the profile field to remove.
+ by_admin: Whether this change was made by an administrator.
+ deactivation: Whether this change was made while deactivating the user.
+ """
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ if not by_admin and target_user != requester.user:
+ raise AuthError(400, "Cannot set another user's profile")
+
+ await self.store.delete_profile_field(target_user, field_name)
+
+ # Custom fields do not propagate into the user directory *or* rooms.
+ profile = await self.store.get_profileinfo(target_user)
+ await self._third_party_rules.on_profile_update(
+ target_user.to_string(), profile, by_admin, deactivation
+ )
+
async def on_profile_query(self, args: JsonDict) -> JsonDict:
"""Handles federation profile query requests."""
@@ -419,13 +536,24 @@ class ProfileHandler:
just_field = args.get("field", None)
- response = {}
+ response: JsonDict = {}
try:
- if just_field is None or just_field == "displayname":
+ if just_field is None or just_field == ProfileFields.DISPLAYNAME:
response["displayname"] = await self.store.get_profile_displayname(user)
- if just_field is None or just_field == "avatar_url":
+ if just_field is None or just_field == ProfileFields.AVATAR_URL:
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
+
+ if self.hs.config.experimental.msc4133_enabled:
+ if just_field is None:
+ response.update(await self.store.get_profile_fields(user))
+ elif just_field not in (
+ ProfileFields.DISPLAYNAME,
+ ProfileFields.AVATAR_URL,
+ ):
+ response[just_field] = await self.store.get_profile_field(
+ user, just_field
+ )
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 63b8a9364a..ebd5a33ea5 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -92,6 +92,23 @@ class CapabilitiesRestServlet(RestServlet):
"enabled": self.config.experimental.msc3664_enabled,
}
+ if self.config.experimental.msc4133_enabled:
+ response["capabilities"]["uk.tcpip.msc4133.profile_fields"] = {
+ "enabled": True,
+ }
+
+ # Ensure this is consistent with the legacy m.set_displayname and
+ # m.set_avatar_url.
+ disallowed = []
+ if not self.config.registration.enable_set_displayname:
+ disallowed.append("displayname")
+ if not self.config.registration.enable_set_avatar_url:
+ disallowed.append("avatar_url")
+ if disallowed:
+ response["capabilities"]["uk.tcpip.msc4133.profile_fields"][
+ "disallowed"
+ ] = disallowed
+
return HTTPStatus.OK, response
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index ef59582865..8326d8017c 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -21,10 +21,13 @@
"""This module contains REST servlets to do with profile: /profile/<paths>"""
+import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import ProfileFields
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.profile import MAX_CUSTOM_FIELD_LEN
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -33,7 +36,8 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonValue, UserID
+from synapse.util.stringutils import is_namedspaced_grammar
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -91,6 +95,11 @@ class ProfileDisplaynameRestServlet(RestServlet):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
+ if not UserID.is_valid(user_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+ )
+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)
@@ -101,9 +110,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
new_name = content["displayname"]
except Exception:
raise SynapseError(
- code=400,
- msg="Unable to parse name",
- errcode=Codes.BAD_JSON,
+ 400, "Missing key 'displayname'", errcode=Codes.MISSING_PARAM
)
propagate = _read_propagate(self.hs, request)
@@ -166,6 +173,11 @@ class ProfileAvatarURLRestServlet(RestServlet):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
+ if not UserID.is_valid(user_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+ )
+
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester)
@@ -232,7 +244,180 @@ class ProfileRestServlet(RestServlet):
return 200, ret
+class UnstableProfileFieldRestServlet(RestServlet):
+ PATTERNS = [
+ re.compile(
+ r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
+ )
+ ]
+ CATEGORY = "Event sending requests"
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.hs = hs
+ self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str, field_name: str
+ ) -> Tuple[int, JsonDict]:
+ requester_user = None
+
+ if self.hs.config.server.require_auth_for_profile_requests:
+ requester = await self.auth.get_user_by_req(request)
+ requester_user = requester.user
+
+ if not UserID.is_valid(user_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+ )
+
+ if not field_name:
+ raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
+
+ if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
+ raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
+ if not is_namedspaced_grammar(field_name):
+ raise SynapseError(
+ 400,
+ "Field name does not follow Common Namespaced Identifier Grammar",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ user = UserID.from_string(user_id)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
+
+ if field_name == ProfileFields.DISPLAYNAME:
+ field_value: JsonValue = await self.profile_handler.get_displayname(user)
+ elif field_name == ProfileFields.AVATAR_URL:
+ field_value = await self.profile_handler.get_avatar_url(user)
+ else:
+ field_value = await self.profile_handler.get_profile_field(user, field_name)
+
+ return 200, {field_name: field_value}
+
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str, field_name: str
+ ) -> Tuple[int, JsonDict]:
+ if not UserID.is_valid(user_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+ )
+
+ requester = await self.auth.get_user_by_req(request)
+ user = UserID.from_string(user_id)
+ is_admin = await self.auth.is_server_admin(requester)
+
+ if not field_name:
+ raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
+
+ if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
+ raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
+ if not is_namedspaced_grammar(field_name):
+ raise SynapseError(
+ 400,
+ "Field name does not follow Common Namespaced Identifier Grammar",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ content = parse_json_object_from_request(request)
+ try:
+ new_value = content[field_name]
+ except KeyError:
+ raise SynapseError(
+ 400, f"Missing key '{field_name}'", errcode=Codes.MISSING_PARAM
+ )
+
+ propagate = _read_propagate(self.hs, request)
+
+ requester_suspended = (
+ await self.hs.get_datastores().main.get_user_suspended_status(
+ requester.user.to_string()
+ )
+ )
+
+ if requester_suspended:
+ raise SynapseError(
+ 403,
+ "Updating profile while account is suspended is not allowed.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+
+ if field_name == ProfileFields.DISPLAYNAME:
+ await self.profile_handler.set_displayname(
+ user, requester, new_value, is_admin, propagate=propagate
+ )
+ elif field_name == ProfileFields.AVATAR_URL:
+ await self.profile_handler.set_avatar_url(
+ user, requester, new_value, is_admin, propagate=propagate
+ )
+ else:
+ await self.profile_handler.set_profile_field(
+ user, requester, field_name, new_value, is_admin
+ )
+
+ return 200, {}
+
+ async def on_DELETE(
+ self, request: SynapseRequest, user_id: str, field_name: str
+ ) -> Tuple[int, JsonDict]:
+ if not UserID.is_valid(user_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
+ )
+
+ requester = await self.auth.get_user_by_req(request)
+ user = UserID.from_string(user_id)
+ is_admin = await self.auth.is_server_admin(requester)
+
+ if not field_name:
+ raise SynapseError(400, "Field name too short", errcode=Codes.INVALID_PARAM)
+
+ if len(field_name.encode("utf-8")) > MAX_CUSTOM_FIELD_LEN:
+ raise SynapseError(400, "Field name too long", errcode=Codes.KEY_TOO_LARGE)
+ if not is_namedspaced_grammar(field_name):
+ raise SynapseError(
+ 400,
+ "Field name does not follow Common Namespaced Identifier Grammar",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ propagate = _read_propagate(self.hs, request)
+
+ requester_suspended = (
+ await self.hs.get_datastores().main.get_user_suspended_status(
+ requester.user.to_string()
+ )
+ )
+
+ if requester_suspended:
+ raise SynapseError(
+ 403,
+ "Updating profile while account is suspended is not allowed.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+
+ if field_name == ProfileFields.DISPLAYNAME:
+ await self.profile_handler.set_displayname(
+ user, requester, "", is_admin, propagate=propagate
+ )
+ elif field_name == ProfileFields.AVATAR_URL:
+ await self.profile_handler.set_avatar_url(
+ user, requester, "", is_admin, propagate=propagate
+ )
+ else:
+ await self.profile_handler.delete_profile_field(
+ user, requester, field_name, is_admin
+ )
+
+ return 200, {}
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+ # The specific displayname / avatar URL / custom field endpoints *must* appear
+ # before their corresponding generic profile endpoint.
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)
+ if hs.config.experimental.msc4133_enabled:
+ UnstableProfileFieldRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index a1d089ebac..266a0b835b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -172,6 +172,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc4140": bool(self.config.server.max_event_delay_ms),
# Simplified sliding sync
"org.matrix.simplified_msc3575": msc3575_enabled,
+ # Arbitrary key-value profile fields.
+ "uk.tcpip.msc4133": self.config.experimental.msc4133_enabled,
},
},
)
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 41cf08211f..30d8a58d96 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -18,8 +18,13 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import TYPE_CHECKING, Optional
+import json
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast
+from canonicaljson import encode_canonical_json
+
+from synapse.api.constants import ProfileFields
+from synapse.api.errors import Codes, StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -27,13 +32,17 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.roommember import ProfileInfo
-from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, UserID
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict, JsonValue, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
+# The number of bytes that the serialized profile can have.
+MAX_PROFILE_SIZE = 65536
+
+
class ProfileWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -201,6 +210,89 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue:
+ """
+ Get a custom profile field for a user.
+
+ Args:
+ user_id: The user's ID.
+ field_name: The custom profile field name.
+
+ Returns:
+ The string value if the field exists, otherwise raises 404.
+ """
+
+ def get_profile_field(txn: LoggingTransaction) -> JsonValue:
+ # This will error if field_name has double quotes in it, but that's not
+ # possible due to the grammar.
+ field_path = f'$."{field_name}"'
+
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?)
+ FROM profiles
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (field_path, field_name, user_id.localpart),
+ )
+
+ # Test exists first since value being None is used for both
+ # missing and a null JSON value.
+ exists, value = cast(Tuple[bool, JsonValue], txn.fetchone())
+ if not exists:
+ raise StoreError(404, "No row found")
+ return value
+
+ else:
+ sql = """
+ SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?)
+ FROM profiles
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (field_path, field_path, user_id.localpart),
+ )
+
+ # If value_type is None, then the value did not exist.
+ value_type, value = cast(
+ Tuple[Optional[str], JsonValue], txn.fetchone()
+ )
+ if not value_type:
+ raise StoreError(404, "No row found")
+ # If value_type is object or array, then need to deserialize the JSON.
+ # Scalar values are properly returned directly.
+ if value_type in ("object", "array"):
+ assert isinstance(value, str)
+ return json.loads(value)
+ return value
+
+ return await self.db_pool.runInteraction("get_profile_field", get_profile_field)
+
+ async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]:
+ """
+ Get all custom profile fields for a user.
+
+ Args:
+ user_id: The user's ID.
+
+ Returns:
+ A dictionary of custom profile fields.
+ """
+ result = await self.db_pool.simple_select_one_onecol(
+ table="profiles",
+ keyvalues={"full_user_id": user_id.to_string()},
+ retcol="fields",
+ desc="get_profile_fields",
+ )
+ # The SQLite driver doesn't automatically convert JSON to
+ # Python objects
+ if isinstance(self.database_engine, Sqlite3Engine) and result:
+ result = json.loads(result)
+ return result or {}
+
async def create_profile(self, user_id: UserID) -> None:
"""
Create a blank profile for a user.
@@ -215,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore):
desc="create_profile",
)
+ def _check_profile_size(
+ self,
+ txn: LoggingTransaction,
+ user_id: UserID,
+ new_field_name: str,
+ new_value: JsonValue,
+ ) -> None:
+ # For each entry there are 4 quotes (2 each for key and value), 1 colon,
+ # and 1 comma.
+ PER_VALUE_EXTRA = 6
+
+ # Add the size of the current custom profile fields, ignoring the entry
+ # which will be overwritten.
+ if isinstance(txn.database_engine, PostgresEngine):
+ size_sql = """
+ SELECT
+ OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url)
+ FROM profiles
+ WHERE
+ user_id = ?
+ """
+ txn.execute(
+ size_sql,
+ (new_field_name, user_id.localpart),
+ )
+ else:
+ size_sql = """
+ SELECT
+ LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url)
+ FROM profiles
+ WHERE
+ user_id = ?
+ """
+ txn.execute(
+ size_sql,
+ # This will error if field_name has double quotes in it, but that's not
+ # possible due to the grammar.
+ (f'$."{new_field_name}"', user_id.localpart),
+ )
+ row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone())
+
+ # The values return null if the column is null.
+ total_bytes = (
+ # Discount the opening and closing braces to avoid double counting,
+ # but add one for a comma.
+ # -2 + 1 = -1
+ (row[0] - 1 if row[0] else 0)
+ + (
+ row[1] + len("displayname") + PER_VALUE_EXTRA
+ if new_field_name != ProfileFields.DISPLAYNAME and row[1]
+ else 0
+ )
+ + (
+ row[2] + len("avatar_url") + PER_VALUE_EXTRA
+ if new_field_name != ProfileFields.AVATAR_URL and row[2]
+ else 0
+ )
+ )
+
+ # Add the length of the field being added + the braces.
+ total_bytes += len(encode_canonical_json({new_field_name: new_value}))
+
+ if total_bytes > MAX_PROFILE_SIZE:
+ raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE)
+
async def set_profile_displayname(
self, user_id: UserID, new_displayname: Optional[str]
) -> None:
@@ -227,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore):
name is removed.
"""
user_localpart = user_id.localpart
- await self.db_pool.simple_upsert(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- values={
- "displayname": new_displayname,
- "full_user_id": user_id.to_string(),
- },
- desc="set_profile_displayname",
+
+ def set_profile_displayname(txn: LoggingTransaction) -> None:
+ if new_displayname is not None:
+ self._check_profile_size(
+ txn, user_id, ProfileFields.DISPLAYNAME, new_displayname
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values={
+ "displayname": new_displayname,
+ "full_user_id": user_id.to_string(),
+ },
+ )
+
+ await self.db_pool.runInteraction(
+ "set_profile_displayname", set_profile_displayname
)
async def set_profile_avatar_url(
@@ -249,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore):
removed.
"""
user_localpart = user_id.localpart
- await self.db_pool.simple_upsert(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
- desc="set_profile_avatar_url",
+
+ def set_profile_avatar_url(txn: LoggingTransaction) -> None:
+ if new_avatar_url is not None:
+ self._check_profile_size(
+ txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values={
+ "avatar_url": new_avatar_url,
+ "full_user_id": user_id.to_string(),
+ },
+ )
+
+ await self.db_pool.runInteraction(
+ "set_profile_avatar_url", set_profile_avatar_url
)
+ async def set_profile_field(
+ self, user_id: UserID, field_name: str, new_value: JsonValue
+ ) -> None:
+ """
+ Set a custom profile field for a user.
+
+ Args:
+ user_id: The user's ID.
+ field_name: The name of the custom profile field.
+ new_value: The value of the custom profile field.
+ """
+
+ # Encode to canonical JSON.
+ canonical_value = encode_canonical_json(new_value)
+
+ def set_profile_field(txn: LoggingTransaction) -> None:
+ self._check_profile_size(txn, user_id, field_name, new_value)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ from psycopg2.extras import Json
+
+ # Note that the || jsonb operator is not recursive, any duplicate
+ # keys will be taken from the second value.
+ sql = """
+ INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb))
+ ON CONFLICT (user_id)
+ DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields
+ """
+
+ txn.execute(
+ sql,
+ (
+ user_id.localpart,
+ user_id.to_string(),
+ field_name,
+ # Pass as a JSON object since we have passing bytes disabled
+ # at the database driver.
+ Json(json.loads(canonical_value)),
+ ),
+ )
+ else:
+ # You may be tempted to use json_patch instead of providing the parameters
+ # twice, but that recursively merges objects instead of replacing.
+ sql = """
+ INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?)))
+ ON CONFLICT (user_id)
+ DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?))
+ """
+ # This will error if field_name has double quotes in it, but that's not
+ # possible due to the grammar.
+ json_field_name = f'$."{field_name}"'
+
+ txn.execute(
+ sql,
+ (
+ user_id.localpart,
+ user_id.to_string(),
+ json_field_name,
+ canonical_value,
+ json_field_name,
+ canonical_value,
+ ),
+ )
+
+ await self.db_pool.runInteraction("set_profile_field", set_profile_field)
+
+ async def delete_profile_field(self, user_id: UserID, field_name: str) -> None:
+ """
+ Remove a custom profile field for a user.
+
+ Args:
+ user_id: The user's ID.
+ field_name: The name of the custom profile field.
+ """
+
+ def delete_profile_field(txn: LoggingTransaction) -> None:
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ UPDATE profiles SET fields = fields - ?
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (field_name, user_id.localpart),
+ )
+ else:
+ sql = """
+ UPDATE profiles SET fields = json_remove(fields, ?)
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ # This will error if field_name has double quotes in it.
+ (f'$."{field_name}"', user_id.localpart),
+ )
+
+ await self.db_pool.runInteraction("delete_profile_field", delete_profile_field)
+
class ProfileStore(ProfileWorkerStore):
pass
diff --git a/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql b/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql
new file mode 100644
index 0000000000..63cbd7ffa9
--- /dev/null
+++ b/synapse/storage/schema/main/delta/88/01_custom_profile_fields.sql
@@ -0,0 +1,15 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 Patrick Cloke
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Custom profile fields.
+ALTER TABLE profiles ADD COLUMN fields JSONB;
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 13ff54b669..32b5bc00c9 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -43,6 +43,14 @@ CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
+# https://spec.matrix.org/v1.13/appendices/#common-namespaced-identifier-grammar
+#
+# At least one character, less than or equal to 255 characters. Must start with
+# a-z, the rest is a-z, 0-9, -, _, or ..
+#
+# This doesn't check anything about validity of namespaces.
+NAMESPACED_GRAMMAR = re.compile(r"^[a-z][a-z0-9_.-]{0,254}$")
+
def random_string(length: int) -> str:
"""Generate a cryptographically secure string of random letters.
@@ -68,6 +76,10 @@ def is_ascii(s: bytes) -> bool:
return True
+def is_namedspaced_grammar(s: str) -> bool:
+ return bool(NAMESPACED_GRAMMAR.match(s))
+
+
def assert_valid_client_secret(client_secret: str) -> None:
"""Validate that a given string matches the client_secret defined by the spec"""
if (
diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py
index bbe8ab1a7c..8af00221c2 100644
--- a/tests/rest/client/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -142,6 +142,50 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc4133_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled_msc4133(
+ self,
+ ) -> None:
+ """Test if set displayname is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
+ self.assertEqual(
+ capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
+ ["displayname"],
+ )
+
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc4133_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled_msc4133(self) -> None:
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, HTTPStatus.OK)
+ self.assertFalse(capabilities["m.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["uk.tcpip.msc4133.profile_fields"]["enabled"])
+ self.assertEqual(
+ capabilities["uk.tcpip.msc4133.profile_fields"]["disallowed"],
+ ["avatar_url"],
+ )
+
@override_config({"enable_3pid_changes": False})
def test_get_change_3pid_capabilities_3pid_disabled(self) -> None:
"""Test if change 3pid is disabled that the server responds it."""
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index a92713d220..708402b792 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -25,16 +25,20 @@ import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional
+from canonicaljson import encode_canonical_json
+
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.profile import MAX_PROFILE_SIZE
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -480,6 +484,298 @@ class ProfileTestCase(unittest.HomeserverTestCase):
# The client requested ?propagate=true, so it should have happened.
self.assertEqual(channel.json_body.get(prop), "http://my.server/pic.gif")
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_missing_custom_field_invalid_field_name(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/[custom_field]",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_get_custom_field_rejects_bad_username(self) -> None:
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{urllib.parse.quote('@alice:')}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "test"})
+
+ # Overwriting the field should work.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "new_Value"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"custom_field": "new_Value"})
+
+ # Deleting the field should work.
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ )
+ self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_non_string(self) -> None:
+ """Non-string fields are supported for custom fields."""
+ fields = {
+ "bool_field": True,
+ "array_field": ["test"],
+ "object_field": {"test": "test"},
+ "numeric_field": 1,
+ "null_field": None,
+ }
+
+ for key, value in fields.items():
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: value},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {"displayname": "owner", **fields})
+
+ # Check getting individual fields works.
+ for key, value in fields.items():
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ )
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+ self.assertEqual(channel.json_body, {key: value})
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_noauth(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"custom_field": "test"},
+ )
+ self.assertEqual(channel.code, 401, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_TOKEN)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_size(self) -> None:
+ """
+ Attempts to set a custom field name that is too long should get a 400 error.
+ """
+ # Key is missing.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/",
+ content={"": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Single key is too large.
+ key = "c" * 500
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ channel = self.make_request(
+ "DELETE",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.KEY_TOO_LARGE)
+
+ # Key doesn't match body.
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/custom_field",
+ content={"diff_key": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_profile_too_long(self) -> None:
+ """
+ Attempts to set a custom field that would push the overall profile too large.
+ """
+ # Get right to the boundary:
+ # len("displayname") + len("owner") + 5 = 21 for the displayname
+ # 1 + 65498 + 5 for key "a" = 65504
+ # 2 braces, 1 comma
+ # 3 + 21 + 65498 = 65522 < 65536.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "a" * 65498},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Get the entire profile.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/v3/profile/{self.owner}",
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ canonical_json = encode_canonical_json(channel.json_body)
+ # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key.
+ # Be one below that so we can prove we're at the boundary.
+ self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8)
+
+ # Postgres stores JSONB with whitespace, while SQLite doesn't.
+ if USE_POSTGRES_FOR_TESTS:
+ ADDITIONAL_CHARS = 0
+ else:
+ ADDITIONAL_CHARS = 1
+
+ # The next one should fail, note the value has a (JSON) length of 2.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "1" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Setting an avatar or (longer) display name should not work.
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/displayname",
+ content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ channel = self.make_request(
+ "PUT",
+ f"/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://foo/bar"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE)
+
+ # Removing a single byte should work.
+ key = "b"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: "" + "a" * ADDITIONAL_CHARS},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Finally, setting a field that already exists to a value that is <= in length should work.
+ key = "a"
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/{key}",
+ content={key: ""},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_displayname(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/displayname",
+ content={"displayname": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ displayname = self._get_displayname()
+ self.assertEqual(displayname, "test")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_avatar_url(self) -> None:
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.owner}/avatar_url",
+ content={"avatar_url": "mxc://test/good"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ avatar_url = self._get_avatar_url()
+ self.assertEqual(avatar_url, "mxc://test/good")
+
+ @unittest.override_config({"experimental_features": {"msc4133_enabled": True}})
+ def test_set_custom_field_other(self) -> None:
+ """Setting someone else's profile field should fail"""
+ channel = self.make_request(
+ "PUT",
+ f"/_matrix/client/unstable/uk.tcpip.msc4133/profile/{self.other}/custom_field",
+ content={"custom_field": "test"},
+ access_token=self.owner_tok,
+ )
+ self.assertEqual(channel.code, 403, channel.result)
+ self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]) -> None:
"""Stores metadata about files in the database.
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 646fd2163e..34c2395ecf 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -20,7 +20,11 @@
#
from synapse.api.errors import SynapseError
-from synapse.util.stringutils import assert_valid_client_secret, base62_encode
+from synapse.util.stringutils import (
+ assert_valid_client_secret,
+ base62_encode,
+ is_namedspaced_grammar,
+)
from .. import unittest
@@ -58,3 +62,25 @@ class StringUtilsTestCase(unittest.TestCase):
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))
self.assertEqual("001c", base62_encode(100, minwidth=4))
+
+ def test_namespaced_identifier(self) -> None:
+ self.assertTrue(is_namedspaced_grammar("test"))
+ self.assertTrue(is_namedspaced_grammar("m.test"))
+ self.assertTrue(is_namedspaced_grammar("org.matrix.test"))
+ self.assertTrue(is_namedspaced_grammar("org.matrix.msc1234"))
+ self.assertTrue(is_namedspaced_grammar("test"))
+ self.assertTrue(is_namedspaced_grammar("t-e_s.t"))
+
+ # Must start with letter.
+ self.assertFalse(is_namedspaced_grammar("1test"))
+ self.assertFalse(is_namedspaced_grammar("-test"))
+ self.assertFalse(is_namedspaced_grammar("_test"))
+ self.assertFalse(is_namedspaced_grammar(".test"))
+
+ # Must contain only a-z, 0-9, -, _, ..
+ self.assertFalse(is_namedspaced_grammar("test/"))
+ self.assertFalse(is_namedspaced_grammar('test"'))
+ self.assertFalse(is_namedspaced_grammar("testö"))
+
+ # Must be < 255 characters.
+ self.assertFalse(is_namedspaced_grammar("t" * 256))
|