diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 9996f195a0..f9d51aed4d 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
+ self._profiles_id_gen = StreamIdGenerator(
+ db_conn, "profiles_extended", "stream_id"
+ )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..a432b1eaab 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,45 +13,159 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from ._base import SQLBaseStore
+import ujson
+
class ProfileStore(SQLBaseStore):
def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles",
- values={"user_id": user_localpart},
- desc="create_profile",
+ return defer.succeed(None)
+
+ @defer.inlineCallbacks
+ def get_profile_displayname(self, user_id):
+ profile = yield self.get_profile_key(
+ user_id, "default", "m.display_name"
)
- def get_profile_displayname(self, user_localpart):
- return self._simple_select_one_onecol(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- retcol="displayname",
- desc="get_profile_displayname",
+ if profile:
+ try:
+ display_name = profile["rows"][0]["display_name"]
+ except (KeyError, IndexError):
+ display_name = None
+ else:
+ display_name = None
+
+ defer.returnValue(display_name)
+
+ def set_profile_displayname(self, user_id, new_displayname):
+ if new_displayname:
+ content = {"rows": [{
+ "display_name": new_displayname
+ }]}
+ else:
+ # TODO: Delete in this case
+ content = {}
+
+ return self.update_profile_key(
+ user_id, "default", "m.display_name", content
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
- desc="set_profile_displayname",
+ @defer.inlineCallbacks
+ def get_profile_avatar_url(self, user_id):
+ profile = yield self.get_profile_key(
+ user_id, "default", "m.avatar_url"
)
- def get_profile_avatar_url(self, user_localpart):
- return self._simple_select_one_onecol(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- retcol="avatar_url",
- desc="get_profile_avatar_url",
+ if profile:
+ try:
+ avatar_url = profile["rows"][0]["avatar_url"]
+ except (KeyError, IndexError):
+ avatar_url = None
+ else:
+ avatar_url = None
+
+ defer.returnValue(avatar_url)
+
+ def set_profile_avatar_url(self, user_id, new_avatar_url):
+ if new_avatar_url:
+ content = {"rows": [{
+ "avatar_url": new_avatar_url
+ }]}
+ else:
+ # TODO: Delete in this case
+ content = {}
+
+ return self.update_profile_key(
+ user_id, "default", "m.avatar_url", content
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
- desc="set_profile_avatar_url",
+ @defer.inlineCallbacks
+ def get_full_profile(self, user_id):
+ rows = yield self._simple_select_list(
+ table="profiles_extended",
+ keyvalues={"user_id": user_id},
+ retcols=("persona", "key", "content",),
)
+
+ personas = {}
+ profile = {"personas": personas}
+ for row in rows:
+ content = ujson.loads(row["content"])
+ personas.setdefault(
+ row["persona"], {"rows": {}}
+ )["rows"][row["key"]] = content
+
+ defer.returnValue(profile)
+
+ @defer.inlineCallbacks
+ def get_persona_profile(self, user_id, persona):
+ rows = yield self._simple_select_list(
+ table="profiles_extended",
+ keyvalues={
+ "user_id": user_id,
+ "persona": persona,
+ },
+ retcols=("key", "content",),
+ )
+
+ persona = {"properties": {
+ row["key"]: ujson.loads(row["content"])
+ for row in rows
+ }}
+
+ defer.returnValue(persona)
+
+ @defer.inlineCallbacks
+ def get_profile_key(self, user_id, persona, key):
+ content_json = yield self._simple_select_one_onecol(
+ table="profiles_extended",
+ keyvalues={
+ "user_id": user_id,
+ "persona": persona,
+ "key": key,
+ },
+ retcol="content",
+ allow_none=True,
+ )
+
+ if content_json:
+ content = ujson.loads(content_json)
+ else:
+ content = None
+
+ defer.returnValue(content)
+
+ def update_profile_key(self, user_id, persona, key, content):
+ content_json = ujson.dumps(content)
+
+ def _update_profile_key_txn(txn, stream_id):
+ self._simple_delete_txn(
+ txn,
+ table="profiles_extended",
+ keyvalues={
+ "user_id": user_id,
+ "persona": persona,
+ "key": key,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="profiles_extended",
+ values={
+ "stream_id": stream_id,
+ "user_id": user_id,
+ "persona": persona,
+ "key": key,
+ "content": content_json,
+ }
+ )
+
+ with self._profiles_id_gen.get_next() as stream_id:
+ return self.runInteraction(
+ "update_profile_key", _update_profile_key_txn,
+ stream_id,
+ )
diff --git a/synapse/storage/schema/delta/38/profile.py b/synapse/storage/schema/delta/38/profile.py
new file mode 100644
index 0000000000..bdd014ffba
--- /dev/null
+++ b/synapse/storage/schema/delta/38/profile.py
@@ -0,0 +1,104 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+CREATE_TABLE = """
+CREATE TABLE profiles_extended (
+ stream_id BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ persona TEXT NOT NULL, -- Which persona this field is in, e.g. `default`
+ key TEXT NOT NULL, -- the key of this field, e.g. `m.display_name`
+ content TEXT NOT NULL -- JSON encoded content of the key
+);
+
+CREATE INDEX profiles_extended_tuple ON profiles_extended(
+ user_id, persona, key, stream_id
+);
+"""
+
+POSTGRES_UPDATE_DISPLAY_NAME = """
+INSERT INTO profiles_extended (stream_id, user_id, persona, key, content)
+SELECT
+ 1,
+ '@' || user_id || ':' || %s,
+ 'm.display_name',
+ '{"rows":["display_name":' || to_json(displayname) || '}]}'
+FROM profiles WHERE displayname IS NOT NULL
+"""
+
+POSTGRES_UPDATE_AVATAR_URL = """
+INSERT INTO profiles_extended (stream_id, user_id, persona, key, content)
+SELECT
+ 1,
+ '@' || user_id || ':' || %s,
+ 'm.avatar_url',
+ '{"rows":[{"avatar_url":' || to_json(avatar_url) || '}]}'
+FROM profiles WHERE avatar_url IS NOT NULL
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(CREATE_TABLE.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(POSTGRES_UPDATE_DISPLAY_NAME, (config.server_name,))
+ cur.execute(POSTGRES_UPDATE_AVATAR_URL, (config.server_name,))
+ else:
+ cur.execute(
+ "SELECT user_id, displayname FROM profiles WHERE displayname IS NOT NULL"
+ )
+ displaynames = []
+ for user_id, displayname in cur.fetchall():
+ displaynames.append((
+ 1,
+ "@%s:%s" % (user_id, config.server_name),
+ "default",
+ "m.display_name",
+ ujson.dumps({"rows": [{"display_name": displayname}]}),
+ ))
+ cur.executemany(
+ "INSERT INTO profiles_extended"
+ " (stream_id, user_id, persona, key, content)"
+ " VALUES (?,?,?,?,?)",
+ displaynames
+ )
+
+ cur.execute(
+ "SELECT user_id, avatar_url FROM profiles WHERE avatar_url IS NOT NULL"
+ )
+ avatar_urls = []
+ for user_id, avatar_url in cur.fetchall():
+ avatar_urls.append((
+ 1,
+ "@%s:%s" % (user_id, config.server_name),
+ "default",
+ "m.avatar_url",
+ ujson.dumps({"rows": [{"avatar_url": avatar_url}]}),
+ ))
+ cur.executemany(
+ "INSERT INTO profiles_extended"
+ " (stream_id, user_id, persona, key, content)"
+ " VALUES (?,?,?,?,?)",
+ avatar_urls
+ )
|