summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/federation_client.py4
-rw-r--r--synapse/federation/federation_server.py19
-rw-r--r--synapse/federation/transport/client.py10
-rw-r--r--synapse/federation/transport/server.py14
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/handlers/profile.py46
-rw-r--r--synapse/handlers/room_member.py10
-rw-r--r--synapse/push/mailer.py3
-rw-r--r--synapse/rest/__init__.py2
-rw-r--r--synapse/rest/client/v2_alpha/profiles_extended.py114
-rw-r--r--synapse/storage/__init__.py3
-rw-r--r--synapse/storage/profile.py170
-rw-r--r--synapse/storage/schema/delta/38/profile.py104
-rw-r--r--tests/handlers/test_profile.py11
-rw-r--r--tests/storage/test_profile.py19
15 files changed, 474 insertions, 64 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 94e76b1978..972c1fa91a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -878,6 +878,10 @@ class FederationClient(FederationBase):
 
         defer.returnValue(signed_events)
 
+    def get_profile(self, user_id, persona=None, key=None):
+        destination = get_domain_from_id(user_id)
+        return self.transport_layer.get_profile(destination, user_id, persona, key)
+
     def event_from_pdu_json(self, pdu_json, outlier=False):
         event = FrozenEvent(
             pdu_json
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3fa7b2315c..9287600cc8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -442,6 +442,25 @@ class FederationServer(FederationBase):
             "events": [ev.get_pdu_json(time_now) for ev in missing_events],
         })
 
+    def on_profile_request(self, user_id, persona, key):
+        """Handle a /profile/ request. Persona and key parameters are optional.
+
+        Args:
+            user_id (str)
+            persona (str): Optional if `key` not also set. Returns only info from
+                the given persona.
+            key (str): Optional. Returns only the given `key`.
+        """
+        if not self.hs.is_mine_id(user_id):
+            raise SynapseError(400, "Not a local user")
+
+        if key is not None:
+            return self.store.get_profile_key(user_id, persona, key)
+        elif persona is not None:
+            return self.store.get_persona_profile(user_id, persona)
+        else:
+            return self.store.get_full_profile(user_id)
+
     @log_function
     def on_openid_userinfo(self, token):
         ts_now_ms = self._clock.time_msec()
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index db45c7826c..51ae3a56fb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -396,3 +396,13 @@ class TransportLayerClient(object):
         )
 
         defer.returnValue(content)
+
+    def get_profile(self, destination, user_id, persona=None, key=None):
+        if key:
+            path = PREFIX + "/profile/%s/%s/%s" % (user_id, persona, key)
+        elif persona:
+            path = PREFIX + "/profile/%s/%s/" % (user_id, persona)
+        else:
+            path = PREFIX + "/profile/%s/" % (user_id,)
+
+        return self.client.get_json(destination, path)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fec337be64..103cfd4427 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -578,6 +578,19 @@ class FederationVersionServlet(BaseFederationServlet):
         }))
 
 
+class FederationProfileServlet(BaseFederationServlet):
+    # This matches all three of:
+    #  - /profile/@foo:bar/
+    #  - /profile/@foo:bar/default/
+    #  - /profile/@foo:bar/default/m.displayname
+    PATH = "/profile/(?P<user_id>[^/]+)/((?P<persona>[^/]+)/(?P<key>[^/]+)?)?$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, user_id, persona, key):
+        profile = yield self.handler.on_profile_request(user_id, persona, key)
+        defer.returnValue((200, profile))
+
+
 SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
@@ -602,6 +615,7 @@ SERVLET_CLASSES = (
     OpenIdUserInfo,
     PublicRoomList,
     FederationVersionServlet,
+    FederationProfileServlet,
 )
 
 
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index abfa8c65a4..5a20a847ee 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -202,8 +202,13 @@ class MessageHandler(BaseHandler):
                 content = builder.content
 
                 try:
-                    content["displayname"] = yield profile.get_displayname(target)
-                    content["avatar_url"] = yield profile.get_avatar_url(target)
+                    display_name = yield profile.get_displayname(target)
+                    if display_name:
+                        content["displayname"] = display_name
+
+                    avatar_url = yield profile.get_avatar_url(target)
+                    if avatar_url:
+                        content["avatar_url"] = avatar_url
                 except Exception as e:
                     logger.info(
                         "Failed to get profile information for %r: %s",
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 87f74dfb8e..d0adf4e934 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -39,11 +39,11 @@ class ProfileHandler(BaseHandler):
     @defer.inlineCallbacks
     def get_displayname(self, target_user):
         if self.hs.is_mine(target_user):
-            displayname = yield self.store.get_profile_displayname(
-                target_user.localpart
+            display_name = yield self.store.get_profile_displayname(
+                target_user.to_string(),
             )
 
-            defer.returnValue(displayname)
+            defer.returnValue(display_name)
         else:
             try:
                 result = yield self.federation.make_query(
@@ -78,7 +78,7 @@ class ProfileHandler(BaseHandler):
             new_displayname = None
 
         yield self.store.set_profile_displayname(
-            target_user.localpart, new_displayname
+            target_user.to_string(), new_displayname
         )
 
         yield self._update_join_states(requester)
@@ -87,7 +87,7 @@ class ProfileHandler(BaseHandler):
     def get_avatar_url(self, target_user):
         if self.hs.is_mine(target_user):
             avatar_url = yield self.store.get_profile_avatar_url(
-                target_user.localpart
+                target_user.to_string(),
             )
 
             defer.returnValue(avatar_url)
@@ -121,7 +121,7 @@ class ProfileHandler(BaseHandler):
             raise AuthError(400, "Cannot set another user's avatar_url")
 
         yield self.store.set_profile_avatar_url(
-            target_user.localpart, new_avatar_url
+            target_user.to_string(), new_avatar_url
         )
 
         yield self._update_join_states(requester)
@@ -137,13 +137,13 @@ class ProfileHandler(BaseHandler):
         response = {}
 
         if just_field is None or just_field == "displayname":
-            response["displayname"] = yield self.store.get_profile_displayname(
-                user.localpart
+            response["displayname"] = yield self.get_displayname(
+                user
             )
 
         if just_field is None or just_field == "avatar_url":
-            response["avatar_url"] = yield self.store.get_profile_avatar_url(
-                user.localpart
+            response["avatar_url"] = yield self.get_avatar_url(
+                user
             )
 
         defer.returnValue(response)
@@ -180,3 +180,29 @@ class ProfileHandler(BaseHandler):
                     "Failed to update join event for room %s - %s",
                     j.room_id, str(e.message)
                 )
+
+    def get_full_profile_for_user(self, user_id):
+        if self.hs.is_mine_id(user_id):
+            return self.store.get_full_profile(user_id)
+        else:
+            return self.federation.get_profile(user_id)
+
+    def get_persona_profile_for_user(self, user_id, persona):
+        if self.hs.is_mine_id(user_id):
+            return self.store.get_persona_profile(user_id, persona)
+        else:
+            return self.federation.get_profile(user_id, persona)
+
+    def get_profile_key_for_user(self, user_id, persona, key):
+        if self.hs.is_mine_id(user_id):
+            return self.store.get_profile_key(user_id, persona, key)
+        else:
+            return self.federation.get_profile(user_id, persona, key)
+
+    def update_profile_key(self, user_id, persona, key, content):
+        if self.hs.is_mine_id(user_id):
+            return self.store.update_profile_key(
+                user_id, persona, key, content
+            )
+        else:
+            raise AuthError("Cannot set a remote profile")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba49075a20..73dcf326a6 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -245,8 +245,14 @@ class RoomMemberHandler(BaseHandler):
                 content["membership"] = Membership.JOIN
 
                 profile = self.hs.get_handlers().profile_handler
-                content["displayname"] = yield profile.get_displayname(target)
-                content["avatar_url"] = yield profile.get_avatar_url(target)
+
+                display_name = yield profile.get_displayname(target)
+                if display_name:
+                    content["displayname"] = display_name
+
+                avatar_url = yield profile.get_avatar_url(target)
+                if avatar_url:
+                    content["avatar_url"] = avatar_url
 
                 if requester.is_guest:
                     content["kind"] = "guest"
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 53551632b6..d72f98dee7 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -25,7 +25,6 @@ from synapse.util.async import concurrently_execute
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event, descriptor_from_member_events
 )
-from synapse.types import UserID
 from synapse.api.errors import StoreError
 from synapse.api.constants import EventTypes
 from synapse.visibility import filter_events_for_client
@@ -130,7 +129,7 @@ class Mailer(object):
 
         try:
             user_display_name = yield self.store.get_profile_displayname(
-                UserID.from_string(user_id).localpart
+                user_id
             )
             if user_display_name is None:
                 user_display_name = user_id
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index f9f5a3e077..d6fc666195 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
     devices,
     thirdparty,
     sendtodevice,
+    profiles_extended,
 )
 
 from synapse.http.server import JsonResource
@@ -98,3 +99,4 @@ class ClientRestResource(JsonResource):
         devices.register_servlets(hs, client_resource)
         thirdparty.register_servlets(hs, client_resource)
         sendtodevice.register_servlets(hs, client_resource)
+        profiles_extended.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/profiles_extended.py b/synapse/rest/client/v2_alpha/profiles_extended.py
new file mode 100644
index 0000000000..a0520d33e0
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/profiles_extended.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ._base import client_v2_patterns
+
+from synapse.api.errors import NotFoundError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from twisted.internet import defer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FullProfileServlet(RestServlet):
+    PATTERNS = client_v2_patterns(
+        "/profile_extended/(?P<user_id>[^/]+)/$"
+    )
+
+    EXPIRES_MS = 3600 * 1000
+
+    def __init__(self, hs):
+        super(FullProfileServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.profile_handler = hs.get_handlers().profile_handler
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id):
+        yield self.auth.get_user_by_req(request)
+
+        profile = yield self.profile_handler.get_full_profile_for_user(user_id)
+
+        defer.returnValue((200, profile))
+
+
+class ProfilePersonaServlet(RestServlet):
+    PATTERNS = client_v2_patterns(
+        "/profile_extended/(?P<user_id>[^/]+)/(?P<persona>[^/]+)/$"
+    )
+
+    EXPIRES_MS = 3600 * 1000
+
+    def __init__(self, hs):
+        super(ProfilePersonaServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.profile_handler = hs.get_handlers().profile_handler
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, persona):
+        yield self.auth.get_user_by_req(request)
+
+        profile = yield self.profile_handler.get_persona_profile_for_user(
+            user_id, persona
+        )
+
+        if profile:
+            defer.returnValue((200, profile))
+        else:
+            raise NotFoundError()
+
+
+class ProfileTupleServlet(RestServlet):
+    PATTERNS = client_v2_patterns(
+        "/profile_extended/(?P<user_id>[^/]+)/(?P<persona>[^/]+)/(?P<key>[^/]+)$"
+    )
+
+    EXPIRES_MS = 3600 * 1000
+
+    def __init__(self, hs):
+        super(ProfileTupleServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.profile_handler = hs.get_handlers().profile_handler
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, persona, key):
+        yield self.auth.get_user_by_req(request)
+
+        profile = yield self.profile_handler.get_profile_key_for_user(
+            user_id, persona, key
+        )
+
+        if profile is not None:
+            defer.returnValue((200, profile))
+        else:
+            raise NotFoundError()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, user_id, persona, key):
+        yield self.auth.get_user_by_req(request)
+
+        content = parse_json_object_from_request(request)
+
+        yield self.profile_handler.update_profile_key(user_id, persona, key, content)
+
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    FullProfileServlet(hs).register(http_server)
+    ProfileTupleServlet(hs).register(http_server)
+    ProfilePersonaServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 9996f195a0..f9d51aed4d 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
+        self._profiles_id_gen = StreamIdGenerator(
+            db_conn, "profiles_extended", "stream_id"
+        )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..a432b1eaab 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,45 +13,159 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from ._base import SQLBaseStore
 
+import ujson
+
 
 class ProfileStore(SQLBaseStore):
     def create_profile(self, user_localpart):
-        return self._simple_insert(
-            table="profiles",
-            values={"user_id": user_localpart},
-            desc="create_profile",
+        return defer.succeed(None)
+
+    @defer.inlineCallbacks
+    def get_profile_displayname(self, user_id):
+        profile = yield self.get_profile_key(
+            user_id, "default", "m.display_name"
         )
 
-    def get_profile_displayname(self, user_localpart):
-        return self._simple_select_one_onecol(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            retcol="displayname",
-            desc="get_profile_displayname",
+        if profile:
+            try:
+                display_name = profile["rows"][0]["display_name"]
+            except (KeyError, IndexError):
+                display_name = None
+        else:
+            display_name = None
+
+        defer.returnValue(display_name)
+
+    def set_profile_displayname(self, user_id, new_displayname):
+        if new_displayname:
+            content = {"rows": [{
+                "display_name": new_displayname
+            }]}
+        else:
+            # TODO: Delete in this case
+            content = {}
+
+        return self.update_profile_key(
+            user_id, "default", "m.display_name", content
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"displayname": new_displayname},
-            desc="set_profile_displayname",
+    @defer.inlineCallbacks
+    def get_profile_avatar_url(self, user_id):
+        profile = yield self.get_profile_key(
+            user_id, "default", "m.avatar_url"
         )
 
-    def get_profile_avatar_url(self, user_localpart):
-        return self._simple_select_one_onecol(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            retcol="avatar_url",
-            desc="get_profile_avatar_url",
+        if profile:
+            try:
+                avatar_url = profile["rows"][0]["avatar_url"]
+            except (KeyError, IndexError):
+                avatar_url = None
+        else:
+            avatar_url = None
+
+        defer.returnValue(avatar_url)
+
+    def set_profile_avatar_url(self, user_id, new_avatar_url):
+        if new_avatar_url:
+            content = {"rows": [{
+                "avatar_url": new_avatar_url
+            }]}
+        else:
+            # TODO: Delete in this case
+            content = {}
+
+        return self.update_profile_key(
+            user_id, "default", "m.avatar_url", content
         )
 
-    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"avatar_url": new_avatar_url},
-            desc="set_profile_avatar_url",
+    @defer.inlineCallbacks
+    def get_full_profile(self, user_id):
+        rows = yield self._simple_select_list(
+            table="profiles_extended",
+            keyvalues={"user_id": user_id},
+            retcols=("persona", "key", "content",),
         )
+
+        personas = {}
+        profile = {"personas": personas}
+        for row in rows:
+            content = ujson.loads(row["content"])
+            personas.setdefault(
+                row["persona"], {"rows": {}}
+            )["rows"][row["key"]] = content
+
+        defer.returnValue(profile)
+
+    @defer.inlineCallbacks
+    def get_persona_profile(self, user_id, persona):
+        rows = yield self._simple_select_list(
+            table="profiles_extended",
+            keyvalues={
+                "user_id": user_id,
+                "persona": persona,
+            },
+            retcols=("key", "content",),
+        )
+
+        persona = {"properties": {
+            row["key"]: ujson.loads(row["content"])
+            for row in rows
+        }}
+
+        defer.returnValue(persona)
+
+    @defer.inlineCallbacks
+    def get_profile_key(self, user_id, persona, key):
+        content_json = yield self._simple_select_one_onecol(
+            table="profiles_extended",
+            keyvalues={
+                "user_id": user_id,
+                "persona": persona,
+                "key": key,
+            },
+            retcol="content",
+            allow_none=True,
+        )
+
+        if content_json:
+            content = ujson.loads(content_json)
+        else:
+            content = None
+
+        defer.returnValue(content)
+
+    def update_profile_key(self, user_id, persona, key, content):
+        content_json = ujson.dumps(content)
+
+        def _update_profile_key_txn(txn, stream_id):
+            self._simple_delete_txn(
+                txn,
+                table="profiles_extended",
+                keyvalues={
+                    "user_id": user_id,
+                    "persona": persona,
+                    "key": key,
+                }
+            )
+
+            self._simple_insert_txn(
+                txn,
+                table="profiles_extended",
+                values={
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "persona": persona,
+                    "key": key,
+                    "content": content_json,
+                }
+            )
+
+        with self._profiles_id_gen.get_next() as stream_id:
+            return self.runInteraction(
+                "update_profile_key", _update_profile_key_txn,
+                stream_id,
+            )
diff --git a/synapse/storage/schema/delta/38/profile.py b/synapse/storage/schema/delta/38/profile.py
new file mode 100644
index 0000000000..bdd014ffba
--- /dev/null
+++ b/synapse/storage/schema/delta/38/profile.py
@@ -0,0 +1,104 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+CREATE_TABLE = """
+CREATE TABLE profiles_extended (
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    persona TEXT NOT NULL,  -- Which persona this field is in, e.g. `default`
+    key TEXT NOT NULL,  -- the key of this field, e.g. `m.display_name`
+    content TEXT NOT NULL  -- JSON encoded content of the key
+);
+
+CREATE INDEX profiles_extended_tuple ON profiles_extended(
+    user_id, persona, key, stream_id
+);
+"""
+
+POSTGRES_UPDATE_DISPLAY_NAME = """
+INSERT INTO profiles_extended (stream_id, user_id, persona, key, content)
+SELECT
+    1,
+    '@' || user_id || ':' || %s,
+    'm.display_name',
+    '{"rows":["display_name":' || to_json(displayname) || '}]}'
+FROM profiles WHERE displayname IS NOT NULL
+"""
+
+POSTGRES_UPDATE_AVATAR_URL = """
+INSERT INTO profiles_extended (stream_id, user_id, persona, key, content)
+SELECT
+    1,
+    '@' || user_id || ':' || %s,
+    'm.avatar_url',
+    '{"rows":[{"avatar_url":' || to_json(avatar_url) || '}]}'
+FROM profiles WHERE avatar_url IS NOT NULL
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    for statement in get_statements(CREATE_TABLE.splitlines()):
+        cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        cur.execute(POSTGRES_UPDATE_DISPLAY_NAME, (config.server_name,))
+        cur.execute(POSTGRES_UPDATE_AVATAR_URL, (config.server_name,))
+    else:
+        cur.execute(
+            "SELECT user_id, displayname FROM profiles WHERE displayname IS NOT NULL"
+        )
+        displaynames = []
+        for user_id, displayname in cur.fetchall():
+            displaynames.append((
+                1,
+                "@%s:%s" % (user_id, config.server_name),
+                "default",
+                "m.display_name",
+                ujson.dumps({"rows": [{"display_name": displayname}]}),
+            ))
+        cur.executemany(
+            "INSERT INTO profiles_extended"
+            " (stream_id, user_id, persona, key, content)"
+            " VALUES (?,?,?,?,?)",
+            displaynames
+        )
+
+        cur.execute(
+            "SELECT user_id, avatar_url FROM profiles WHERE avatar_url IS NOT NULL"
+        )
+        avatar_urls = []
+        for user_id, avatar_url in cur.fetchall():
+            avatar_urls.append((
+                1,
+                "@%s:%s" % (user_id, config.server_name),
+                "default",
+                "m.avatar_url",
+                ujson.dumps({"rows": [{"avatar_url": avatar_url}]}),
+            ))
+        cur.executemany(
+            "INSERT INTO profiles_extended"
+            " (stream_id, user_id, persona, key, content)"
+            " VALUES (?,?,?,?,?)",
+            avatar_urls
+        )
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f1f664275f..3fd3ed323f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -76,7 +76,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_my_name(self):
         yield self.store.set_profile_displayname(
-            self.frank.localpart, "Frank"
+            self.frank.to_string(), "Frank"
         )
 
         displayname = yield self.handler.get_displayname(self.frank)
@@ -92,7 +92,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)),
+            (yield self.store.get_profile_displayname(self.frank.to_string())),
             "Frank Jr."
         )
 
@@ -123,8 +123,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
-        yield self.store.set_profile_displayname("caroline", "Caroline")
+        yield self.store.set_profile_displayname("@caroline:test", "Caroline")
 
         response = yield self.query_handlers["profile"](
             {"user_id": "@caroline:test", "field": "displayname"}
@@ -135,7 +134,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
         yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+            self.frank.to_string(), "http://my.server/me.png"
         )
 
         avatar_url = yield self.handler.get_avatar_url(self.frank)
@@ -150,6 +149,6 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (yield self.store.get_profile_avatar_url(self.frank.to_string())),
             "http://my.server/pic.gif"
         )
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24118bbc86..04603a9e9a 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -17,7 +17,6 @@
 from tests import unittest
 from twisted.internet import defer
 
-from synapse.storage.profile import ProfileStore
 from synapse.types import UserID
 
 from tests.utils import setup_test_homeserver
@@ -29,36 +28,28 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()
 
-        self.store = ProfileStore(hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(
-            self.u_frank.localpart
-        )
-
         yield self.store.set_profile_displayname(
-            self.u_frank.localpart, "Frank"
+            self.u_frank.to_string(), "Frank"
         )
 
         self.assertEquals(
             "Frank",
-            (yield self.store.get_profile_displayname(self.u_frank.localpart))
+            (yield self.store.get_profile_displayname(self.u_frank.to_string()))
         )
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(
-            self.u_frank.localpart
-        )
-
         yield self.store.set_profile_avatar_url(
-            self.u_frank.localpart, "http://my.site/here"
+            self.u_frank.to_string(), "http://my.site/here"
         )
 
         self.assertEquals(
             "http://my.site/here",
-            (yield self.store.get_profile_avatar_url(self.u_frank.localpart))
+            (yield self.store.get_profile_avatar_url(self.u_frank.to_string()))
         )