summary refs log tree commit diff
path: root/synapse/storage/databases/main/profile.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/profile.py')
-rw-r--r--synapse/storage/databases/main/profile.py312
1 files changed, 296 insertions, 16 deletions
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py

index 41cf08211f..30d8a58d96 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -18,8 +18,13 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING, Optional +import json +from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from canonicaljson import encode_canonical_json + +from synapse.api.constants import ProfileFields +from synapse.api.errors import Codes, StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -27,13 +32,17 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict, JsonValue, UserID if TYPE_CHECKING: from synapse.server import HomeServer +# The number of bytes that the serialized profile can have. +MAX_PROFILE_SIZE = 65536 + + class ProfileWorkerStore(SQLBaseStore): def __init__( self, @@ -201,6 +210,89 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + async def get_profile_field(self, user_id: UserID, field_name: str) -> JsonValue: + """ + Get a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The custom profile field name. + + Returns: + The string value if the field exists, otherwise raises 404. + """ + + def get_profile_field(txn: LoggingTransaction) -> JsonValue: + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + field_path = f'$."{field_name}"' + + if isinstance(self.database_engine, PostgresEngine): + sql = """ + SELECT JSONB_PATH_EXISTS(fields, ?), JSONB_EXTRACT_PATH(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_name, user_id.localpart), + ) + + # Test exists first since value being None is used for both + # missing and a null JSON value. + exists, value = cast(Tuple[bool, JsonValue], txn.fetchone()) + if not exists: + raise StoreError(404, "No row found") + return value + + else: + sql = """ + SELECT JSON_TYPE(fields, ?), JSON_EXTRACT(fields, ?) + FROM profiles + WHERE user_id = ? + """ + txn.execute( + sql, + (field_path, field_path, user_id.localpart), + ) + + # If value_type is None, then the value did not exist. + value_type, value = cast( + Tuple[Optional[str], JsonValue], txn.fetchone() + ) + if not value_type: + raise StoreError(404, "No row found") + # If value_type is object or array, then need to deserialize the JSON. + # Scalar values are properly returned directly. + if value_type in ("object", "array"): + assert isinstance(value, str) + return json.loads(value) + return value + + return await self.db_pool.runInteraction("get_profile_field", get_profile_field) + + async def get_profile_fields(self, user_id: UserID) -> Dict[str, str]: + """ + Get all custom profile fields for a user. + + Args: + user_id: The user's ID. + + Returns: + A dictionary of custom profile fields. + """ + result = await self.db_pool.simple_select_one_onecol( + table="profiles", + keyvalues={"full_user_id": user_id.to_string()}, + retcol="fields", + desc="get_profile_fields", + ) + # The SQLite driver doesn't automatically convert JSON to + # Python objects + if isinstance(self.database_engine, Sqlite3Engine) and result: + result = json.loads(result) + return result or {} + async def create_profile(self, user_id: UserID) -> None: """ Create a blank profile for a user. @@ -215,6 +307,71 @@ class ProfileWorkerStore(SQLBaseStore): desc="create_profile", ) + def _check_profile_size( + self, + txn: LoggingTransaction, + user_id: UserID, + new_field_name: str, + new_value: JsonValue, + ) -> None: + # For each entry there are 4 quotes (2 each for key and value), 1 colon, + # and 1 comma. + PER_VALUE_EXTRA = 6 + + # Add the size of the current custom profile fields, ignoring the entry + # which will be overwritten. + if isinstance(txn.database_engine, PostgresEngine): + size_sql = """ + SELECT + OCTET_LENGTH((fields - ?)::text), OCTET_LENGTH(displayname), OCTET_LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + (new_field_name, user_id.localpart), + ) + else: + size_sql = """ + SELECT + LENGTH(json_remove(fields, ?)), LENGTH(displayname), LENGTH(avatar_url) + FROM profiles + WHERE + user_id = ? + """ + txn.execute( + size_sql, + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + (f'$."{new_field_name}"', user_id.localpart), + ) + row = cast(Tuple[Optional[int], Optional[int], Optional[int]], txn.fetchone()) + + # The values return null if the column is null. + total_bytes = ( + # Discount the opening and closing braces to avoid double counting, + # but add one for a comma. + # -2 + 1 = -1 + (row[0] - 1 if row[0] else 0) + + ( + row[1] + len("displayname") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.DISPLAYNAME and row[1] + else 0 + ) + + ( + row[2] + len("avatar_url") + PER_VALUE_EXTRA + if new_field_name != ProfileFields.AVATAR_URL and row[2] + else 0 + ) + ) + + # Add the length of the field being added + the braces. + total_bytes += len(encode_canonical_json({new_field_name: new_value})) + + if total_bytes > MAX_PROFILE_SIZE: + raise StoreError(400, "Profile too large", Codes.PROFILE_TOO_LARGE) + async def set_profile_displayname( self, user_id: UserID, new_displayname: Optional[str] ) -> None: @@ -227,14 +384,25 @@ class ProfileWorkerStore(SQLBaseStore): name is removed. """ user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={ - "displayname": new_displayname, - "full_user_id": user_id.to_string(), - }, - desc="set_profile_displayname", + + def set_profile_displayname(txn: LoggingTransaction) -> None: + if new_displayname is not None: + self._check_profile_size( + txn, user_id, ProfileFields.DISPLAYNAME, new_displayname + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "displayname": new_displayname, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_displayname", set_profile_displayname ) async def set_profile_avatar_url( @@ -249,13 +417,125 @@ class ProfileWorkerStore(SQLBaseStore): removed. """ user_localpart = user_id.localpart - await self.db_pool.simple_upsert( - table="profiles", - keyvalues={"user_id": user_localpart}, - values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()}, - desc="set_profile_avatar_url", + + def set_profile_avatar_url(txn: LoggingTransaction) -> None: + if new_avatar_url is not None: + self._check_profile_size( + txn, user_id, ProfileFields.AVATAR_URL, new_avatar_url + ) + + self.db_pool.simple_upsert_txn( + txn, + table="profiles", + keyvalues={"user_id": user_localpart}, + values={ + "avatar_url": new_avatar_url, + "full_user_id": user_id.to_string(), + }, + ) + + await self.db_pool.runInteraction( + "set_profile_avatar_url", set_profile_avatar_url ) + async def set_profile_field( + self, user_id: UserID, field_name: str, new_value: JsonValue + ) -> None: + """ + Set a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + new_value: The value of the custom profile field. + """ + + # Encode to canonical JSON. + canonical_value = encode_canonical_json(new_value) + + def set_profile_field(txn: LoggingTransaction) -> None: + self._check_profile_size(txn, user_id, field_name, new_value) + + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import Json + + # Note that the || jsonb operator is not recursive, any duplicate + # keys will be taken from the second value. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_BUILD_OBJECT(?, ?::jsonb)) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = COALESCE(profiles.fields, '{}'::jsonb) || EXCLUDED.fields + """ + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + field_name, + # Pass as a JSON object since we have passing bytes disabled + # at the database driver. + Json(json.loads(canonical_value)), + ), + ) + else: + # You may be tempted to use json_patch instead of providing the parameters + # twice, but that recursively merges objects instead of replacing. + sql = """ + INSERT INTO profiles (user_id, full_user_id, fields) VALUES (?, ?, JSON_OBJECT(?, JSON(?))) + ON CONFLICT (user_id) + DO UPDATE SET full_user_id = EXCLUDED.full_user_id, fields = JSON_SET(COALESCE(profiles.fields, '{}'), ?, JSON(?)) + """ + # This will error if field_name has double quotes in it, but that's not + # possible due to the grammar. + json_field_name = f'$."{field_name}"' + + txn.execute( + sql, + ( + user_id.localpart, + user_id.to_string(), + json_field_name, + canonical_value, + json_field_name, + canonical_value, + ), + ) + + await self.db_pool.runInteraction("set_profile_field", set_profile_field) + + async def delete_profile_field(self, user_id: UserID, field_name: str) -> None: + """ + Remove a custom profile field for a user. + + Args: + user_id: The user's ID. + field_name: The name of the custom profile field. + """ + + def delete_profile_field(txn: LoggingTransaction) -> None: + if isinstance(self.database_engine, PostgresEngine): + sql = """ + UPDATE profiles SET fields = fields - ? + WHERE user_id = ? + """ + txn.execute( + sql, + (field_name, user_id.localpart), + ) + else: + sql = """ + UPDATE profiles SET fields = json_remove(fields, ?) + WHERE user_id = ? + """ + txn.execute( + sql, + # This will error if field_name has double quotes in it. + (f'$."{field_name}"', user_id.localpart), + ) + + await self.db_pool.runInteraction("delete_profile_field", delete_profile_field) + class ProfileStore(ProfileWorkerStore): pass