diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 41cf08211f..30d8a58d96 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -18,8 +18,13 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import TYPE_CHECKING, Optional
+import json
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast
+from canonicaljson import encode_canonical_json
+
+from synapse.api.constants import ProfileFields
+from synapse.api.errors import Codes, StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -27,13 +32,17 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.roommember import ProfileInfo
-from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, UserID
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict, JsonValue, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
+# The maximum size, in bytes, that a serialized profile may be.
+MAX_PROFILE_SIZE = 65536
+
+
class ProfileWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -201,6 +210,89 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue:
+ """
+ Get a custom profile field for a user.
+
+ Args:
+ user_id: The user's ID.
+ field_name: The custom profile field name.
+
+ Returns:
+            The JSON value if the field exists; otherwise a 404 StoreError is raised.
+ """
+
+ def get_profile_field(txn: LoggingTransaction) -> JsonValue:
+ # This will error if field_name has double quotes in it, but that's not
+ # possible due to the grammar.
+ field_path = f'$."{field_name}"'
+
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?)
+ FROM profiles
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (field_path, field_name, user_id.localpart),
+ )
+
+                # Check the exists flag first, since a None value is ambiguous:
+                # it can mean either a missing field or a JSON null value.
+ exists, value = cast(Tuple[bool, JsonValue], txn.fetchone())
+ if not exists:
+ raise StoreError(404, "No row found")
+ return value
+
+ else:
+ sql = """
+ SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?)
+ FROM profiles
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (field_path, field_path, user_id.localpart),
+ )
+
+ # If value_type is None, then the value did not exist.
+ value_type, value = cast(
+ Tuple[Optional[str], JsonValue], txn.fetchone()
+ )
+ if not value_type:
+ raise StoreError(404, "No row found")
+ # If value_type is object or array, then need to deserialize the JSON.
+ # Scalar values are properly returned directly.
+ if value_type in ("object", "array"):
+ assert isinstance(value, str)
+ return json.loads(value)
+ return value
+
+ return await self.db_pool.runInteraction("get_profile_field", get_profile_field)
+
+ async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]:
+ """
+ Get all custom profile fields for a user.
+
+ Args:
+ user_id: The user's ID.
+
+ Returns:
+ A dictionary of custom profile fields.
+ """
+ result = await self.db_pool.simple_select_one_onecol(
+ table="profiles",
+ keyvalues={"full_user_id": user_id.to_string()},
+ retcol="fields",
+ desc="get_profile_fields",
+ )
+ # The SQLite driver doesn't automatically convert JSON to
+ # Python objects
+ if isinstance(self.database_engine, Sqlite3Engine) and result:
+ result = json.loads(result)
+ return result or {}
+
async def create_profile(self, user_id: UserID) -> None:
"""
Create a blank profile for a user.
@@ -215,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore):
desc="create_profile",
)
+ def _check_profile_size(
+ self,
+ txn: LoggingTransaction,
+ user_id: UserID,
+ new_field_name: str,
+ new_value: JsonValue,
+ ) -> None:
+ # For each entry there are 4 quotes (2 each for key and value), 1 colon,
+ # and 1 comma.
+ PER_VALUE_EXTRA = 6
+
+ # Add the size of the current custom profile fields, ignoring the entry
+ # which will be overwritten.
+ if isinstance(txn.database_engine, PostgresEngine):
+ size_sql = """
+ SELECT
+ OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url)
+ FROM profiles
+ WHERE
+ user_id = ?
+ """
+ txn.execute(
+ size_sql,
+ (new_field_name, user_id.localpart),
+ )
+ else:
+ size_sql = """
+ SELECT
+ LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url)
+ FROM profiles
+ WHERE
+ user_id = ?
+ """
+ txn.execute(
+ size_sql,
+ # This will error if field_name has double quotes in it, but that's not
+ # possible due to the grammar.
+ (f'$."{new_field_name}"', user_id.localpart),
+ )
+ row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone())
+
+        # Each returned value is NULL when the corresponding column is NULL.
+ total_bytes = (
+ # Discount the opening and closing braces to avoid double counting,
+ # but add one for a comma.
+ # -2 + 1 = -1
+ (row[0] - 1 if row[0] else 0)
+ + (
+ row[1] + len("displayname") + PER_VALUE_EXTRA
+ if new_field_name != ProfileFields.DISPLAYNAME and row[1]
+ else 0
+ )
+ + (
+ row[2] + len("avatar_url") + PER_VALUE_EXTRA
+ if new_field_name != ProfileFields.AVATAR_URL and row[2]
+ else 0
+ )
+ )
+
+ # Add the length of the field being added + the braces.
+ total_bytes += len(encode_canonical_json({new_field_name: new_value}))
+
+ if total_bytes > MAX_PROFILE_SIZE:
+ raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE)
+
async def set_profile_displayname(
self, user_id: UserID, new_displayname: Optional[str]
) -> None:
@@ -227,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore):
name is removed.
"""
user_localpart = user_id.localpart
- await self.db_pool.simple_upsert(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- values={
- "displayname": new_displayname,
- "full_user_id": user_id.to_string(),
- },
- desc="set_profile_displayname",
+
+ def set_profile_displayname(txn: LoggingTransaction) -> None:
+ if new_displayname is not None:
+ self._check_profile_size(
+ txn, user_id, ProfileFields.DISPLAYNAME, new_displayname
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values={
+ "displayname": new_displayname,
+ "full_user_id": user_id.to_string(),
+ },
+ )
+
+ await self.db_pool.runInteraction(
+ "set_profile_displayname", set_profile_displayname
)
async def set_profile_avatar_url(
@@ -249,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore):
removed.
"""
user_localpart = user_id.localpart
- await self.db_pool.simple_upsert(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
- desc="set_profile_avatar_url",
+
+ def set_profile_avatar_url(txn: LoggingTransaction) -> None:
+ if new_avatar_url is not None:
+ self._check_profile_size(
+ txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values={
+ "avatar_url": new_avatar_url,
+ "full_user_id": user_id.to_string(),
+ },
+ )
+
+ await self.db_pool.runInteraction(
+ "set_profile_avatar_url", set_profile_avatar_url
)
+ async def set_profile_field(
+ self, user_id: UserID, field_name: str, new_value: JsonValue
+ ) -> None:
+ """
+ Set a custom profile field for a user.
+
+ Args:
+ user_id: The user's ID.
+ field_name: The name of the custom profile field.
+ new_value: The value of the custom profile field.
+ """
+
+ # Encode to canonical JSON.
+ canonical_value = encode_canonical_json(new_value)
+
+ def set_profile_field(txn: LoggingTransaction) -> None:
+ self._check_profile_size(txn, user_id, field_name, new_value)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ from psycopg2.extras import Json
+
+ # Note that the || jsonb operator is not recursive, any duplicate
+ # keys will be taken from the second value.
+ sql = """
+ INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb))
+ ON CONFLICT (user_id)
+ DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields
+ """
+
+ txn.execute(
+ sql,
+ (
+ user_id.localpart,
+ user_id.to_string(),
+ field_name,
+ # Pass as a JSON object since we have passing bytes disabled
+ # at the database driver.
+ Json(json.loads(canonical_value)),
+ ),
+ )
+ else:
+ # You may be tempted to use json_patch instead of providing the parameters
+ # twice, but that recursively merges objects instead of replacing.
+ sql = """
+ INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?)))
+ ON CONFLICT (user_id)
+ DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?))
+ """
+ # This will error if field_name has double quotes in it, but that's not
+ # possible due to the grammar.
+ json_field_name = f'$."{field_name}"'
+
+ txn.execute(
+ sql,
+ (
+ user_id.localpart,
+ user_id.to_string(),
+ json_field_name,
+ canonical_value,
+ json_field_name,
+ canonical_value,
+ ),
+ )
+
+ await self.db_pool.runInteraction("set_profile_field", set_profile_field)
+
+ async def delete_profile_field(self, user_id: UserID, field_name: str) -> None:
+ """
+ Remove a custom profile field for a user.
+
+ Args:
+ user_id: The user's ID.
+ field_name: The name of the custom profile field.
+ """
+
+ def delete_profile_field(txn: LoggingTransaction) -> None:
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ UPDATE profiles SET fields = fields - ?
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+ (field_name, user_id.localpart),
+ )
+ else:
+ sql = """
+ UPDATE profiles SET fields = json_remove(fields, ?)
+ WHERE user_id = ?
+ """
+ txn.execute(
+ sql,
+                    # This will error if field_name has double quotes in it, but that's not possible due to the grammar.
+ (f'$."{field_name}"', user_id.localpart),
+ )
+
+ await self.db_pool.runInteraction("delete_profile_field", delete_profile_field)
+
class ProfileStore(ProfileWorkerStore):
pass
|