summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py3
-rw-r--r--synapse/storage/profile.py170
-rw-r--r--synapse/storage/schema/delta/38/profile.py104
3 files changed, 249 insertions, 28 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py

index 9996f195a0..f9d51aed4d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._profiles_id_gen = StreamIdGenerator( + db_conn, "profiles_extended", "stream_id" + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..a432b1eaab 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py
@@ -13,45 +13,159 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from ._base import SQLBaseStore +import ujson + class ProfileStore(SQLBaseStore): def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", + return defer.succeed(None) + + @defer.inlineCallbacks + def get_profile_displayname(self, user_id): + profile = yield self.get_profile_key( + user_id, "default", "m.display_name" ) - def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="displayname", - desc="get_profile_displayname", + if profile: + try: + display_name = profile["rows"][0]["display_name"] + except (KeyError, IndexError): + display_name = None + else: + display_name = None + + defer.returnValue(display_name) + + def set_profile_displayname(self, user_id, new_displayname): + if new_displayname: + content = {"rows": [{ + "display_name": new_displayname + }]} + else: + # TODO: Delete in this case + content = {} + + return self.update_profile_key( + user_id, "default", "m.display_name", content ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", + @defer.inlineCallbacks + def get_profile_avatar_url(self, user_id): + profile = yield self.get_profile_key( + user_id, "default", "m.avatar_url" ) - def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="avatar_url", - desc="get_profile_avatar_url", + if profile: + try: + avatar_url = profile["rows"][0]["avatar_url"] + except (KeyError, IndexError): + avatar_url = None + else: + avatar_url = None + + defer.returnValue(avatar_url) + + def set_profile_avatar_url(self, user_id, new_avatar_url): + if new_avatar_url: + content = {"rows": [{ + "avatar_url": new_avatar_url + }]} + else: + # TODO: Delete in this case + content = {} + + return self.update_profile_key( + user_id, "default", "m.avatar_url", content ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, - desc="set_profile_avatar_url", + @defer.inlineCallbacks + def get_full_profile(self, user_id): + rows = yield self._simple_select_list( + table="profiles_extended", + keyvalues={"user_id": user_id}, + retcols=("persona", "key", "content",), ) + + personas = {} + profile = {"personas": personas} + for row in rows: + content = ujson.loads(row["content"]) + personas.setdefault( + row["persona"], {"rows": {}} + )["rows"][row["key"]] = content + + defer.returnValue(profile) + + @defer.inlineCallbacks + def get_persona_profile(self, user_id, persona): + rows = yield self._simple_select_list( + table="profiles_extended", + keyvalues={ + "user_id": user_id, + "persona": persona, + }, + retcols=("key", "content",), + ) + + persona = {"properties": { + row["key"]: ujson.loads(row["content"]) + for row in rows + }} + + defer.returnValue(persona) + + @defer.inlineCallbacks + def get_profile_key(self, user_id, persona, key): + content_json = yield self._simple_select_one_onecol( + table="profiles_extended", + keyvalues={ + "user_id": user_id, + "persona": persona, + "key": key, + }, + retcol="content", + allow_none=True, + ) + + if content_json: + content = ujson.loads(content_json) + else: + content = None + + defer.returnValue(content) + + def update_profile_key(self, user_id, persona, key, content): + content_json = ujson.dumps(content) + + def _update_profile_key_txn(txn, stream_id): + self._simple_delete_txn( + txn, + table="profiles_extended", + keyvalues={ + "user_id": user_id, + "persona": persona, + "key": key, + } + ) + + self._simple_insert_txn( + txn, + table="profiles_extended", + values={ + "stream_id": stream_id, + "user_id": user_id, + "persona": persona, + "key": key, + "content": content_json, + } + ) + + with self._profiles_id_gen.get_next() as stream_id: + return self.runInteraction( + "update_profile_key", _update_profile_key_txn, + stream_id, + ) diff --git a/synapse/storage/schema/delta/38/profile.py b/synapse/storage/schema/delta/38/profile.py new file mode 100644
index 0000000000..bdd014ffba --- /dev/null +++ b/synapse/storage/schema/delta/38/profile.py
@@ -0,0 +1,104 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine + +import logging +import ujson + +logger = logging.getLogger(__name__) + +CREATE_TABLE = """ +CREATE TABLE profiles_extended ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + persona TEXT NOT NULL, -- Which persona this field is in, e.g. `default` + key TEXT NOT NULL, -- the key of this field, e.g. `m.display_name` + content TEXT NOT NULL -- JSON encoded content of the key +); + +CREATE INDEX profiles_extended_tuple ON profiles_extended( + user_id, persona, key, stream_id +); +""" + +POSTGRES_UPDATE_DISPLAY_NAME = """ +INSERT INTO profiles_extended (stream_id, user_id, persona, key, content) +SELECT + 1, + '@' || user_id || ':' || %s, + 'm.display_name', + '{"rows":["display_name":' || to_json(displayname) || '}]}' +FROM profiles WHERE displayname IS NOT NULL +""" + +POSTGRES_UPDATE_AVATAR_URL = """ +INSERT INTO profiles_extended (stream_id, user_id, persona, key, content) +SELECT + 1, + '@' || user_id || ':' || %s, + 'm.avatar_url', + '{"rows":[{"avatar_url":' || to_json(avatar_url) || '}]}' +FROM profiles WHERE avatar_url IS NOT NULL +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute(POSTGRES_UPDATE_DISPLAY_NAME, (config.server_name,)) + cur.execute(POSTGRES_UPDATE_AVATAR_URL, (config.server_name,)) + else: + cur.execute( + "SELECT user_id, displayname FROM profiles WHERE displayname IS NOT NULL" + ) + displaynames = [] + for user_id, displayname in cur.fetchall(): + displaynames.append(( + 1, + "@%s:%s" % (user_id, config.server_name), + "default", + "m.display_name", + ujson.dumps({"rows": [{"display_name": displayname}]}), + )) + cur.executemany( + "INSERT INTO profiles_extended" + " (stream_id, user_id, persona, key, content)" + " VALUES (?,?,?,?,?)", + displaynames + ) + + cur.execute( + "SELECT user_id, avatar_url FROM profiles WHERE avatar_url IS NOT NULL" + ) + avatar_urls = [] + for user_id, avatar_url in cur.fetchall(): + avatar_urls.append(( + 1, + "@%s:%s" % (user_id, config.server_name), + "default", + "m.avatar_url", + ujson.dumps({"rows": [{"avatar_url": avatar_url}]}), + )) + cur.executemany( + "INSERT INTO profiles_extended" + " (stream_id, user_id, persona, key, content)" + " VALUES (?,?,?,?,?)", + avatar_urls + )