diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 94e76b1978..972c1fa91a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -878,6 +878,10 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events)
+ def get_profile(self, user_id, persona=None, key=None):
+ destination = get_domain_from_id(user_id)
+ return self.transport_layer.get_profile(destination, user_id, persona, key)
+
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3fa7b2315c..9287600cc8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -442,6 +442,25 @@ class FederationServer(FederationBase):
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
+ def on_profile_request(self, user_id, persona, key):
+ """Handle a /profile/ request. Persona and key parameters are optional.
+
+ Args:
+ user_id (str)
+ persona (str): Optional if `key` not also set. Returns only info from
+ the given persona.
+ key (str): Optional. Returns only the given `key`.
+ """
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Not a local user")
+
+ if key is not None:
+ return self.store.get_profile_key(user_id, persona, key)
+ elif persona is not None:
+ return self.store.get_persona_profile(user_id, persona)
+ else:
+ return self.store.get_full_profile(user_id)
+
@log_function
def on_openid_userinfo(self, token):
ts_now_ms = self._clock.time_msec()
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index db45c7826c..51ae3a56fb 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -396,3 +396,13 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
+
+ def get_profile(self, destination, user_id, persona=None, key=None):
+ if key:
+ path = PREFIX + "/profile/%s/%s/%s" % (user_id, persona, key)
+ elif persona:
+ path = PREFIX + "/profile/%s/%s/" % (user_id, persona)
+ else:
+ path = PREFIX + "/profile/%s/" % (user_id,)
+
+ return self.client.get_json(destination, path)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fec337be64..103cfd4427 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -578,6 +578,19 @@ class FederationVersionServlet(BaseFederationServlet):
}))
+class FederationProfileServlet(BaseFederationServlet):
+ # This matches all three of:
+ # - /profile/@foo:bar/
+ # - /profile/@foo:bar/default/
+ # - /profile/@foo:bar/default/m.displayname
+ PATH = "/profile/(?P<user_id>[^/]+)/((?P<persona>[^/]+)/(?P<key>[^/]+)?)?$"
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, user_id, persona, key):
+ profile = yield self.handler.on_profile_request(user_id, persona, key)
+ defer.returnValue((200, profile))
+
+
SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@@ -602,6 +615,7 @@ SERVLET_CLASSES = (
OpenIdUserInfo,
PublicRoomList,
FederationVersionServlet,
+ FederationProfileServlet,
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index abfa8c65a4..5a20a847ee 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -202,8 +202,13 @@ class MessageHandler(BaseHandler):
content = builder.content
try:
- content["displayname"] = yield profile.get_displayname(target)
- content["avatar_url"] = yield profile.get_avatar_url(target)
+ display_name = yield profile.get_displayname(target)
+ if display_name:
+ content["displayname"] = display_name
+
+ avatar_url = yield profile.get_avatar_url(target)
+ if avatar_url:
+ content["avatar_url"] = avatar_url
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 87f74dfb8e..d0adf4e934 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -39,11 +39,11 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
- displayname = yield self.store.get_profile_displayname(
- target_user.localpart
+ display_name = yield self.store.get_profile_displayname(
+ target_user.to_string(),
)
- defer.returnValue(displayname)
+ defer.returnValue(display_name)
else:
try:
result = yield self.federation.make_query(
@@ -78,7 +78,7 @@ class ProfileHandler(BaseHandler):
new_displayname = None
yield self.store.set_profile_displayname(
- target_user.localpart, new_displayname
+ target_user.to_string(), new_displayname
)
yield self._update_join_states(requester)
@@ -87,7 +87,7 @@ class ProfileHandler(BaseHandler):
def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
avatar_url = yield self.store.get_profile_avatar_url(
- target_user.localpart
+ target_user.to_string(),
)
defer.returnValue(avatar_url)
@@ -121,7 +121,7 @@ class ProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url(
- target_user.localpart, new_avatar_url
+ target_user.to_string(), new_avatar_url
)
yield self._update_join_states(requester)
@@ -137,13 +137,13 @@ class ProfileHandler(BaseHandler):
response = {}
if just_field is None or just_field == "displayname":
- response["displayname"] = yield self.store.get_profile_displayname(
- user.localpart
+ response["displayname"] = yield self.get_displayname(
+ user
)
if just_field is None or just_field == "avatar_url":
- response["avatar_url"] = yield self.store.get_profile_avatar_url(
- user.localpart
+ response["avatar_url"] = yield self.get_avatar_url(
+ user
)
defer.returnValue(response)
@@ -180,3 +180,29 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s",
j.room_id, str(e.message)
)
+
+ def get_full_profile_for_user(self, user_id):
+ if self.hs.is_mine_id(user_id):
+ return self.store.get_full_profile(user_id)
+ else:
+ return self.federation.get_profile(user_id)
+
+ def get_persona_profile_for_user(self, user_id, persona):
+ if self.hs.is_mine_id(user_id):
+ return self.store.get_persona_profile(user_id, persona)
+ else:
+ return self.federation.get_profile(user_id, persona)
+
+ def get_profile_key_for_user(self, user_id, persona, key):
+ if self.hs.is_mine_id(user_id):
+ return self.store.get_profile_key(user_id, persona, key)
+ else:
+ return self.federation.get_profile(user_id, persona, key)
+
+ def update_profile_key(self, user_id, persona, key, content):
+ if self.hs.is_mine_id(user_id):
+ return self.store.update_profile_key(
+ user_id, persona, key, content
+ )
+ else:
+ raise AuthError("Cannot set a remote profile")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba49075a20..73dcf326a6 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -245,8 +245,14 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler
- content["displayname"] = yield profile.get_displayname(target)
- content["avatar_url"] = yield profile.get_avatar_url(target)
+
+ display_name = yield profile.get_displayname(target)
+ if display_name:
+ content["displayname"] = display_name
+
+ avatar_url = yield profile.get_avatar_url(target)
+ if avatar_url:
+ content["avatar_url"] = avatar_url
if requester.is_guest:
content["kind"] = "guest"
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 53551632b6..d72f98dee7 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -25,7 +25,6 @@ from synapse.util.async import concurrently_execute
from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events
)
-from synapse.types import UserID
from synapse.api.errors import StoreError
from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_client
@@ -130,7 +129,7 @@ class Mailer(object):
try:
user_display_name = yield self.store.get_profile_displayname(
- UserID.from_string(user_id).localpart
+ user_id
)
if user_display_name is None:
user_display_name = user_id
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index f9f5a3e077..d6fc666195 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
devices,
thirdparty,
sendtodevice,
+ profiles_extended,
)
from synapse.http.server import JsonResource
@@ -98,3 +99,4 @@ class ClientRestResource(JsonResource):
devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
+ profiles_extended.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/profiles_extended.py b/synapse/rest/client/v2_alpha/profiles_extended.py
new file mode 100644
index 0000000000..a0520d33e0
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/profiles_extended.py
@@ -0,0 +1,114 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ._base import client_v2_patterns
+
+from synapse.api.errors import NotFoundError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from twisted.internet import defer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class FullProfileServlet(RestServlet):
+ PATTERNS = client_v2_patterns(
+ "/profile_extended/(?P<user_id>[^/]+)/$"
+ )
+
+ EXPIRES_MS = 3600 * 1000
+
+ def __init__(self, hs):
+ super(FullProfileServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.profile_handler = hs.get_handlers().profile_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ yield self.auth.get_user_by_req(request)
+
+ profile = yield self.profile_handler.get_full_profile_for_user(user_id)
+
+ defer.returnValue((200, profile))
+
+
+class ProfilePersonaServlet(RestServlet):
+ PATTERNS = client_v2_patterns(
+ "/profile_extended/(?P<user_id>[^/]+)/(?P<persona>[^/]+)/$"
+ )
+
+ EXPIRES_MS = 3600 * 1000
+
+ def __init__(self, hs):
+ super(ProfilePersonaServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.profile_handler = hs.get_handlers().profile_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, persona):
+ yield self.auth.get_user_by_req(request)
+
+ profile = yield self.profile_handler.get_persona_profile_for_user(
+ user_id, persona
+ )
+
+ if profile:
+ defer.returnValue((200, profile))
+ else:
+ raise NotFoundError()
+
+
+class ProfileTupleServlet(RestServlet):
+ PATTERNS = client_v2_patterns(
+ "/profile_extended/(?P<user_id>[^/]+)/(?P<persona>[^/]+)/(?P<key>[^/]+)$"
+ )
+
+ EXPIRES_MS = 3600 * 1000
+
+ def __init__(self, hs):
+ super(ProfileTupleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.profile_handler = hs.get_handlers().profile_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, persona, key):
+ yield self.auth.get_user_by_req(request)
+
+ profile = yield self.profile_handler.get_profile_key_for_user(
+ user_id, persona, key
+ )
+
+ if profile is not None:
+ defer.returnValue((200, profile))
+ else:
+ raise NotFoundError()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, user_id, persona, key):
+ yield self.auth.get_user_by_req(request)
+
+ content = parse_json_object_from_request(request)
+
+ yield self.profile_handler.update_profile_key(user_id, persona, key, content)
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ FullProfileServlet(hs).register(http_server)
+ ProfileTupleServlet(hs).register(http_server)
+ ProfilePersonaServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 9996f195a0..f9d51aed4d 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
+ self._profiles_id_gen = StreamIdGenerator(
+ db_conn, "profiles_extended", "stream_id"
+ )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..a432b1eaab 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,45 +13,159 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
+
from ._base import SQLBaseStore
+import ujson
+
class ProfileStore(SQLBaseStore):
def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles",
- values={"user_id": user_localpart},
- desc="create_profile",
+ return defer.succeed(None)
+
+ @defer.inlineCallbacks
+ def get_profile_displayname(self, user_id):
+ profile = yield self.get_profile_key(
+ user_id, "default", "m.display_name"
)
- def get_profile_displayname(self, user_localpart):
- return self._simple_select_one_onecol(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- retcol="displayname",
- desc="get_profile_displayname",
+ if profile:
+ try:
+ display_name = profile["rows"][0]["display_name"]
+ except (KeyError, IndexError):
+ display_name = None
+ else:
+ display_name = None
+
+ defer.returnValue(display_name)
+
+ def set_profile_displayname(self, user_id, new_displayname):
+ if new_displayname:
+ content = {"rows": [{
+ "display_name": new_displayname
+ }]}
+ else:
+ # TODO: Delete in this case
+ content = {}
+
+ return self.update_profile_key(
+ user_id, "default", "m.display_name", content
)
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
- desc="set_profile_displayname",
+ @defer.inlineCallbacks
+ def get_profile_avatar_url(self, user_id):
+ profile = yield self.get_profile_key(
+ user_id, "default", "m.avatar_url"
)
- def get_profile_avatar_url(self, user_localpart):
- return self._simple_select_one_onecol(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- retcol="avatar_url",
- desc="get_profile_avatar_url",
+ if profile:
+ try:
+ avatar_url = profile["rows"][0]["avatar_url"]
+ except (KeyError, IndexError):
+ avatar_url = None
+ else:
+ avatar_url = None
+
+ defer.returnValue(avatar_url)
+
+ def set_profile_avatar_url(self, user_id, new_avatar_url):
+ if new_avatar_url:
+ content = {"rows": [{
+ "avatar_url": new_avatar_url
+ }]}
+ else:
+ # TODO: Delete in this case
+ content = {}
+
+ return self.update_profile_key(
+ user_id, "default", "m.avatar_url", content
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self._simple_update_one(
- table="profiles",
- keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
- desc="set_profile_avatar_url",
+ @defer.inlineCallbacks
+ def get_full_profile(self, user_id):
+ rows = yield self._simple_select_list(
+ table="profiles_extended",
+ keyvalues={"user_id": user_id},
+ retcols=("persona", "key", "content",),
)
+
+ personas = {}
+ profile = {"personas": personas}
+ for row in rows:
+ content = ujson.loads(row["content"])
+ personas.setdefault(
+ row["persona"], {"rows": {}}
+ )["rows"][row["key"]] = content
+
+ defer.returnValue(profile)
+
+ @defer.inlineCallbacks
+ def get_persona_profile(self, user_id, persona):
+ rows = yield self._simple_select_list(
+ table="profiles_extended",
+ keyvalues={
+ "user_id": user_id,
+ "persona": persona,
+ },
+ retcols=("key", "content",),
+ )
+
+ persona = {"properties": {
+ row["key"]: ujson.loads(row["content"])
+ for row in rows
+ }}
+
+ defer.returnValue(persona)
+
+ @defer.inlineCallbacks
+ def get_profile_key(self, user_id, persona, key):
+ content_json = yield self._simple_select_one_onecol(
+ table="profiles_extended",
+ keyvalues={
+ "user_id": user_id,
+ "persona": persona,
+ "key": key,
+ },
+ retcol="content",
+ allow_none=True,
+ )
+
+ if content_json:
+ content = ujson.loads(content_json)
+ else:
+ content = None
+
+ defer.returnValue(content)
+
+ def update_profile_key(self, user_id, persona, key, content):
+ content_json = ujson.dumps(content)
+
+ def _update_profile_key_txn(txn, stream_id):
+ self._simple_delete_txn(
+ txn,
+ table="profiles_extended",
+ keyvalues={
+ "user_id": user_id,
+ "persona": persona,
+ "key": key,
+ }
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="profiles_extended",
+ values={
+ "stream_id": stream_id,
+ "user_id": user_id,
+ "persona": persona,
+ "key": key,
+ "content": content_json,
+ }
+ )
+
+ with self._profiles_id_gen.get_next() as stream_id:
+ return self.runInteraction(
+ "update_profile_key", _update_profile_key_txn,
+ stream_id,
+ )
diff --git a/synapse/storage/schema/delta/38/profile.py b/synapse/storage/schema/delta/38/profile.py
new file mode 100644
index 0000000000..bdd014ffba
--- /dev/null
+++ b/synapse/storage/schema/delta/38/profile.py
@@ -0,0 +1,104 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+import ujson
+
+logger = logging.getLogger(__name__)
+
+CREATE_TABLE = """
+CREATE TABLE profiles_extended (
+ stream_id BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ persona TEXT NOT NULL, -- Which persona this field is in, e.g. `default`
+ key TEXT NOT NULL, -- the key of this field, e.g. `m.display_name`
+ content TEXT NOT NULL -- JSON encoded content of the key
+);
+
+CREATE INDEX profiles_extended_tuple ON profiles_extended(
+ user_id, persona, key, stream_id
+);
+"""
+
+POSTGRES_UPDATE_DISPLAY_NAME = """
+INSERT INTO profiles_extended (stream_id, user_id, persona, key, content)
+SELECT
+ 1,
+ '@' || user_id || ':' || %s,
+ 'm.display_name',
+ '{"rows":["display_name":' || to_json(displayname) || '}]}'
+FROM profiles WHERE displayname IS NOT NULL
+"""
+
+POSTGRES_UPDATE_AVATAR_URL = """
+INSERT INTO profiles_extended (stream_id, user_id, persona, key, content)
+SELECT
+ 1,
+ '@' || user_id || ':' || %s,
+ 'm.avatar_url',
+ '{"rows":[{"avatar_url":' || to_json(avatar_url) || '}]}'
+FROM profiles WHERE avatar_url IS NOT NULL
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ for statement in get_statements(CREATE_TABLE.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(POSTGRES_UPDATE_DISPLAY_NAME, (config.server_name,))
+ cur.execute(POSTGRES_UPDATE_AVATAR_URL, (config.server_name,))
+ else:
+ cur.execute(
+ "SELECT user_id, displayname FROM profiles WHERE displayname IS NOT NULL"
+ )
+ displaynames = []
+ for user_id, displayname in cur.fetchall():
+ displaynames.append((
+ 1,
+ "@%s:%s" % (user_id, config.server_name),
+ "default",
+ "m.display_name",
+ ujson.dumps({"rows": [{"display_name": displayname}]}),
+ ))
+ cur.executemany(
+ "INSERT INTO profiles_extended"
+ " (stream_id, user_id, persona, key, content)"
+ " VALUES (?,?,?,?,?)",
+ displaynames
+ )
+
+ cur.execute(
+ "SELECT user_id, avatar_url FROM profiles WHERE avatar_url IS NOT NULL"
+ )
+ avatar_urls = []
+ for user_id, avatar_url in cur.fetchall():
+ avatar_urls.append((
+ 1,
+ "@%s:%s" % (user_id, config.server_name),
+ "default",
+ "m.avatar_url",
+ ujson.dumps({"rows": [{"avatar_url": avatar_url}]}),
+ ))
+ cur.executemany(
+ "INSERT INTO profiles_extended"
+ " (stream_id, user_id, persona, key, content)"
+ " VALUES (?,?,?,?,?)",
+ avatar_urls
+ )
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f1f664275f..3fd3ed323f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -76,7 +76,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield self.store.set_profile_displayname(
- self.frank.localpart, "Frank"
+ self.frank.to_string(), "Frank"
)
displayname = yield self.handler.get_displayname(self.frank)
@@ -92,7 +92,7 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_displayname(self.frank.localpart)),
+ (yield self.store.get_profile_displayname(self.frank.to_string())),
"Frank Jr."
)
@@ -123,8 +123,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
- yield self.store.create_profile("caroline")
- yield self.store.set_profile_displayname("caroline", "Caroline")
+ yield self.store.set_profile_displayname("@caroline:test", "Caroline")
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
@@ -135,7 +134,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
- self.frank.localpart, "http://my.server/me.png"
+ self.frank.to_string(), "http://my.server/me.png"
)
avatar_url = yield self.handler.get_avatar_url(self.frank)
@@ -150,6 +149,6 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
- (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+ (yield self.store.get_profile_avatar_url(self.frank.to_string())),
"http://my.server/pic.gif"
)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24118bbc86..04603a9e9a 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -17,7 +17,6 @@
from tests import unittest
from twisted.internet import defer
-from synapse.storage.profile import ProfileStore
from synapse.types import UserID
from tests.utils import setup_test_homeserver
@@ -29,36 +28,28 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver()
- self.store = ProfileStore(hs)
+ self.store = hs.get_datastore()
self.u_frank = UserID.from_string("@frank:test")
@defer.inlineCallbacks
def test_displayname(self):
- yield self.store.create_profile(
- self.u_frank.localpart
- )
-
yield self.store.set_profile_displayname(
- self.u_frank.localpart, "Frank"
+ self.u_frank.to_string(), "Frank"
)
self.assertEquals(
"Frank",
- (yield self.store.get_profile_displayname(self.u_frank.localpart))
+ (yield self.store.get_profile_displayname(self.u_frank.to_string()))
)
@defer.inlineCallbacks
def test_avatar_url(self):
- yield self.store.create_profile(
- self.u_frank.localpart
- )
-
yield self.store.set_profile_avatar_url(
- self.u_frank.localpart, "http://my.site/here"
+ self.u_frank.to_string(), "http://my.site/here"
)
self.assertEquals(
"http://my.site/here",
- (yield self.store.get_profile_avatar_url(self.u_frank.localpart))
+ (yield self.store.get_profile_avatar_url(self.u_frank.to_string()))
)
|