diff options
-rw-r--r-- | synapse/federation/federation_client.py | 4 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 19 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 10 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 14 | ||||
-rw-r--r-- | synapse/handlers/message.py | 9 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 46 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 10 | ||||
-rw-r--r-- | synapse/push/mailer.py | 3 | ||||
-rw-r--r-- | synapse/rest/__init__.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/profiles_extended.py | 114 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 3 | ||||
-rw-r--r-- | synapse/storage/profile.py | 170 | ||||
-rw-r--r-- | synapse/storage/schema/delta/38/profile.py | 104 | ||||
-rw-r--r-- | tests/handlers/test_profile.py | 11 | ||||
-rw-r--r-- | tests/storage/test_profile.py | 19 |
15 files changed, 474 insertions, 64 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 94e76b1978..972c1fa91a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -878,6 +878,10 @@ class FederationClient(FederationBase): defer.returnValue(signed_events) + def get_profile(self, user_id, persona=None, key=None): + destination = get_domain_from_id(user_id) + return self.transport_layer.get_profile(destination, user_id, persona, key) + def event_from_pdu_json(self, pdu_json, outlier=False): event = FrozenEvent( pdu_json diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3fa7b2315c..9287600cc8 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -442,6 +442,25 @@ class FederationServer(FederationBase): "events": [ev.get_pdu_json(time_now) for ev in missing_events], }) + def on_profile_request(self, user_id, persona, key): + """Handle a /profile/ request. Persona and key parameters are optional. + + Args: + user_id (str) + persona (str): Optional if `key` not also set. Returns only info from + the given persona. + key (str): Optional. Returns only the given `key`. + """ + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Not a local user") + + if key is not None: + return self.store.get_profile_key(user_id, persona, key) + elif persona is not None: + return self.store.get_persona_profile(user_id, persona) + else: + return self.store.get_full_profile(user_id) + @log_function def on_openid_userinfo(self, token): ts_now_ms = self._clock.time_msec() diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index db45c7826c..51ae3a56fb 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -396,3 +396,13 @@ class TransportLayerClient(object): ) defer.returnValue(content) + + def get_profile(self, destination, user_id, persona=None, key=None): + if key: + path = PREFIX + "/profile/%s/%s/%s" % (user_id, persona, key) + elif persona: + path = PREFIX + "/profile/%s/%s/" % (user_id, persona) + else: + path = PREFIX + "/profile/%s/" % (user_id,) + + return self.client.get_json(destination, path) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fec337be64..103cfd4427 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -578,6 +578,19 @@ class FederationVersionServlet(BaseFederationServlet): })) +class FederationProfileServlet(BaseFederationServlet): + # This matches all three of: + # - /profile/@foo:bar/ + # - /profile/@foo:bar/default/ + # - /profile/@foo:bar/default/m.displayname + PATH = "/profile/(?P<user_id>[^/]+)/((?P<persona>[^/]+)/(?P<key>[^/]+)?)?$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, user_id, persona, key): + profile = yield self.handler.on_profile_request(user_id, persona, key) + defer.returnValue((200, profile)) + + SERVLET_CLASSES = ( FederationSendServlet, FederationPullServlet, @@ -602,6 +615,7 @@ SERVLET_CLASSES = ( OpenIdUserInfo, PublicRoomList, FederationVersionServlet, + FederationProfileServlet, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index abfa8c65a4..5a20a847ee 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -202,8 +202,13 @@ class MessageHandler(BaseHandler): content = builder.content try: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + display_name = yield profile.get_displayname(target) + if display_name: + content["displayname"] = display_name + + avatar_url = yield profile.get_avatar_url(target) + if avatar_url: + content["avatar_url"] = avatar_url except Exception as e: logger.info( "Failed to get profile information for %r: %s", diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 87f74dfb8e..d0adf4e934 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -39,11 +39,11 @@ class ProfileHandler(BaseHandler): @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): - displayname = yield self.store.get_profile_displayname( - target_user.localpart + display_name = yield self.store.get_profile_displayname( + target_user.to_string(), ) - defer.returnValue(displayname) + defer.returnValue(display_name) else: try: result = yield self.federation.make_query( @@ -78,7 +78,7 @@ class ProfileHandler(BaseHandler): new_displayname = None yield self.store.set_profile_displayname( - target_user.localpart, new_displayname + target_user.to_string(), new_displayname ) yield self._update_join_states(requester) @@ -87,7 +87,7 @@ class ProfileHandler(BaseHandler): def get_avatar_url(self, target_user): if self.hs.is_mine(target_user): avatar_url = yield self.store.get_profile_avatar_url( - target_user.localpart + target_user.to_string(), ) defer.returnValue(avatar_url) @@ -121,7 +121,7 @@ class ProfileHandler(BaseHandler): raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( - target_user.localpart, new_avatar_url + target_user.to_string(), new_avatar_url ) yield self._update_join_states(requester) @@ -137,13 +137,13 @@ class ProfileHandler(BaseHandler): response = {} if just_field is None or just_field == "displayname": - response["displayname"] = yield self.store.get_profile_displayname( - user.localpart + response["displayname"] = yield self.get_displayname( + user ) if just_field is None or just_field == "avatar_url": - response["avatar_url"] = yield self.store.get_profile_avatar_url( - user.localpart + response["avatar_url"] = yield self.get_avatar_url( + user ) defer.returnValue(response) @@ -180,3 +180,29 @@ class ProfileHandler(BaseHandler): "Failed to update join event for room %s - %s", j.room_id, str(e.message) ) + + def get_full_profile_for_user(self, user_id): + if self.hs.is_mine_id(user_id): + return self.store.get_full_profile(user_id) + else: + return self.federation.get_profile(user_id) + + def get_persona_profile_for_user(self, user_id, persona): + if self.hs.is_mine_id(user_id): + return self.store.get_persona_profile(user_id, persona) + else: + return self.federation.get_profile(user_id, persona) + + def get_profile_key_for_user(self, user_id, persona, key): + if self.hs.is_mine_id(user_id): + return self.store.get_profile_key(user_id, persona, key) + else: + return self.federation.get_profile(user_id, persona, key) + + def update_profile_key(self, user_id, persona, key, content): + if self.hs.is_mine_id(user_id): + return self.store.update_profile_key( + user_id, persona, key, content + ) + else: + raise AuthError("Cannot set a remote profile") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ba49075a20..73dcf326a6 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -245,8 +245,14 @@ class RoomMemberHandler(BaseHandler): content["membership"] = Membership.JOIN profile = self.hs.get_handlers().profile_handler - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) + + display_name = yield profile.get_displayname(target) + if display_name: + content["displayname"] = display_name + + avatar_url = yield profile.get_avatar_url(target) + if avatar_url: + content["avatar_url"] = avatar_url if requester.is_guest: content["kind"] = "guest" diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 53551632b6..d72f98dee7 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -25,7 +25,6 @@ from synapse.util.async import concurrently_execute from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event, descriptor_from_member_events ) -from synapse.types import UserID from synapse.api.errors import StoreError from synapse.api.constants import EventTypes from synapse.visibility import filter_events_for_client @@ -130,7 +129,7 @@ class Mailer(object): try: user_display_name = yield self.store.get_profile_displayname( - UserID.from_string(user_id).localpart + user_id ) if user_display_name is None: user_display_name = user_id diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index f9f5a3e077..d6fc666195 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import ( devices, thirdparty, sendtodevice, + profiles_extended, ) from synapse.http.server import JsonResource @@ -98,3 +99,4 @@ class ClientRestResource(JsonResource): devices.register_servlets(hs, client_resource) thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) + profiles_extended.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/profiles_extended.py b/synapse/rest/client/v2_alpha/profiles_extended.py new file mode 100644 index 0000000000..a0520d33e0 --- /dev/null +++ b/synapse/rest/client/v2_alpha/profiles_extended.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ._base import client_v2_patterns + +from synapse.api.errors import NotFoundError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from twisted.internet import defer + +import logging + +logger = logging.getLogger(__name__) + + +class FullProfileServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/profile_extended/(?P<user_id>[^/]+)/$" + ) + + EXPIRES_MS = 3600 * 1000 + + def __init__(self, hs): + super(FullProfileServlet, self).__init__() + self.auth = hs.get_auth() + self.profile_handler = hs.get_handlers().profile_handler + + @defer.inlineCallbacks + def on_GET(self, request, user_id): + yield self.auth.get_user_by_req(request) + + profile = yield self.profile_handler.get_full_profile_for_user(user_id) + + defer.returnValue((200, profile)) + + +class ProfilePersonaServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/profile_extended/(?P<user_id>[^/]+)/(?P<persona>[^/]+)/$" + ) + + EXPIRES_MS = 3600 * 1000 + + def __init__(self, hs): + super(ProfilePersonaServlet, self).__init__() + self.auth = hs.get_auth() + self.profile_handler = hs.get_handlers().profile_handler + + @defer.inlineCallbacks + def on_GET(self, request, user_id, persona): + yield self.auth.get_user_by_req(request) + + profile = yield self.profile_handler.get_persona_profile_for_user( + user_id, persona + ) + + if profile: + defer.returnValue((200, profile)) + else: + raise NotFoundError() + + +class ProfileTupleServlet(RestServlet): + PATTERNS = client_v2_patterns( + "/profile_extended/(?P<user_id>[^/]+)/(?P<persona>[^/]+)/(?P<key>[^/]+)$" + ) + + EXPIRES_MS = 3600 * 1000 + + def __init__(self, hs): + super(ProfileTupleServlet, self).__init__() + self.auth = hs.get_auth() + self.profile_handler = hs.get_handlers().profile_handler + + @defer.inlineCallbacks + def on_GET(self, request, user_id, persona, key): + yield self.auth.get_user_by_req(request) + + profile = yield self.profile_handler.get_profile_key_for_user( + user_id, persona, key + ) + + if profile is not None: + defer.returnValue((200, profile)) + else: + raise NotFoundError() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, persona, key): + yield self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + + yield self.profile_handler.update_profile_key(user_id, persona, key, content) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + FullProfileServlet(hs).register(http_server) + ProfileTupleServlet(hs).register(http_server) + ProfilePersonaServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 9996f195a0..f9d51aed4d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._profiles_id_gen = StreamIdGenerator( + db_conn, "profiles_extended", "stream_id" + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..a432b1eaab 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -13,45 +13,159 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from ._base import SQLBaseStore +import ujson + class ProfileStore(SQLBaseStore): def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", + return defer.succeed(None) + + @defer.inlineCallbacks + def get_profile_displayname(self, user_id): + profile = yield self.get_profile_key( + user_id, "default", "m.display_name" ) - def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="displayname", - desc="get_profile_displayname", + if profile: + try: + display_name = profile["rows"][0]["display_name"] + except (KeyError, IndexError): + display_name = None + else: + display_name = None + + defer.returnValue(display_name) + + def set_profile_displayname(self, user_id, new_displayname): + if new_displayname: + content = {"rows": [{ + "display_name": new_displayname + }]} + else: + # TODO: Delete in this case + content = {} + + return self.update_profile_key( + user_id, "default", "m.display_name", content ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", + @defer.inlineCallbacks + def get_profile_avatar_url(self, user_id): + profile = yield self.get_profile_key( + user_id, "default", "m.avatar_url" ) - def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( - table="profiles", - keyvalues={"user_id": user_localpart}, - retcol="avatar_url", - desc="get_profile_avatar_url", + if profile: + try: + avatar_url = profile["rows"][0]["avatar_url"] + except (KeyError, IndexError): + avatar_url = None + else: + avatar_url = None + + defer.returnValue(avatar_url) + + def set_profile_avatar_url(self, user_id, new_avatar_url): + if new_avatar_url: + content = {"rows": [{ + "avatar_url": new_avatar_url + }]} + else: + # TODO: Delete in this case + content = {} + + return self.update_profile_key( + user_id, "default", "m.avatar_url", content ) - def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, - desc="set_profile_avatar_url", + @defer.inlineCallbacks + def get_full_profile(self, user_id): + rows = yield self._simple_select_list( + table="profiles_extended", + keyvalues={"user_id": user_id}, + retcols=("persona", "key", "content",), ) + + personas = {} + profile = {"personas": personas} + for row in rows: + content = ujson.loads(row["content"]) + personas.setdefault( + row["persona"], {"rows": {}} + )["rows"][row["key"]] = content + + defer.returnValue(profile) + + @defer.inlineCallbacks + def get_persona_profile(self, user_id, persona): + rows = yield self._simple_select_list( + table="profiles_extended", + keyvalues={ + "user_id": user_id, + "persona": persona, + }, + retcols=("key", "content",), + ) + + persona = {"properties": { + row["key"]: ujson.loads(row["content"]) + for row in rows + }} + + defer.returnValue(persona) + + @defer.inlineCallbacks + def get_profile_key(self, user_id, persona, key): + content_json = yield self._simple_select_one_onecol( + table="profiles_extended", + keyvalues={ + "user_id": user_id, + "persona": persona, + "key": key, + }, + retcol="content", + allow_none=True, + ) + + if content_json: + content = ujson.loads(content_json) + else: + content = None + + defer.returnValue(content) + + def update_profile_key(self, user_id, persona, key, content): + content_json = ujson.dumps(content) + + def _update_profile_key_txn(txn, stream_id): + self._simple_delete_txn( + txn, + table="profiles_extended", + keyvalues={ + "user_id": user_id, + "persona": persona, + "key": key, + } + ) + + self._simple_insert_txn( + txn, + table="profiles_extended", + values={ + "stream_id": stream_id, + "user_id": user_id, + "persona": persona, + "key": key, + "content": content_json, + } + ) + + with self._profiles_id_gen.get_next() as stream_id: + return self.runInteraction( + "update_profile_key", _update_profile_key_txn, + stream_id, + ) diff --git a/synapse/storage/schema/delta/38/profile.py b/synapse/storage/schema/delta/38/profile.py new file mode 100644 index 0000000000..bdd014ffba --- /dev/null +++ b/synapse/storage/schema/delta/38/profile.py @@ -0,0 +1,104 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine + +import logging +import ujson + +logger = logging.getLogger(__name__) + +CREATE_TABLE = """ +CREATE TABLE profiles_extended ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + persona TEXT NOT NULL, -- Which persona this field is in, e.g. `default` + key TEXT NOT NULL, -- the key of this field, e.g. `m.display_name` + content TEXT NOT NULL -- JSON encoded content of the key +); + +CREATE INDEX profiles_extended_tuple ON profiles_extended( + user_id, persona, key, stream_id +); +""" + +POSTGRES_UPDATE_DISPLAY_NAME = """ +INSERT INTO profiles_extended (stream_id, user_id, persona, key, content) +SELECT + 1, + '@' || user_id || ':' || %s, + 'm.display_name', + '{"rows":["display_name":' || to_json(displayname) || '}]}' +FROM profiles WHERE displayname IS NOT NULL +""" + +POSTGRES_UPDATE_AVATAR_URL = """ +INSERT INTO profiles_extended (stream_id, user_id, persona, key, content) +SELECT + 1, + '@' || user_id || ':' || %s, + 'm.avatar_url', + '{"rows":[{"avatar_url":' || to_json(avatar_url) || '}]}' +FROM profiles WHERE avatar_url IS NOT NULL +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(CREATE_TABLE.splitlines()): + cur.execute(statement) + + +def run_upgrade(cur, database_engine, config, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + cur.execute(POSTGRES_UPDATE_DISPLAY_NAME, (config.server_name,)) + cur.execute(POSTGRES_UPDATE_AVATAR_URL, (config.server_name,)) + else: + cur.execute( + "SELECT user_id, displayname FROM profiles WHERE displayname IS NOT NULL" + ) + displaynames = [] + for user_id, displayname in cur.fetchall(): + displaynames.append(( + 1, + "@%s:%s" % (user_id, config.server_name), + "default", + "m.display_name", + ujson.dumps({"rows": [{"display_name": displayname}]}), + )) + cur.executemany( + "INSERT INTO profiles_extended" + " (stream_id, user_id, persona, key, content)" + " VALUES (?,?,?,?,?)", + displaynames + ) + + cur.execute( + "SELECT user_id, avatar_url FROM profiles WHERE avatar_url IS NOT NULL" + ) + avatar_urls = [] + for user_id, avatar_url in cur.fetchall(): + avatar_urls.append(( + 1, + "@%s:%s" % (user_id, config.server_name), + "default", + "m.avatar_url", + ujson.dumps({"rows": [{"avatar_url": avatar_url}]}), + )) + cur.executemany( + "INSERT INTO profiles_extended" + " (stream_id, user_id, persona, key, content)" + " VALUES (?,?,?,?,?)", + avatar_urls + ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index f1f664275f..3fd3ed323f 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -76,7 +76,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_name(self): yield self.store.set_profile_displayname( - self.frank.localpart, "Frank" + self.frank.to_string(), "Frank" ) displayname = yield self.handler.get_displayname(self.frank) @@ -92,7 +92,7 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_displayname(self.frank.localpart)), + (yield self.store.get_profile_displayname(self.frank.to_string())), "Frank Jr." ) @@ -123,8 +123,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_incoming_fed_query(self): - yield self.store.create_profile("caroline") - yield self.store.set_profile_displayname("caroline", "Caroline") + yield self.store.set_profile_displayname("@caroline:test", "Caroline") response = yield self.query_handlers["profile"]( {"user_id": "@caroline:test", "field": "displayname"} @@ -135,7 +134,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_my_avatar(self): yield self.store.set_profile_avatar_url( - self.frank.localpart, "http://my.server/me.png" + self.frank.to_string(), "http://my.server/me.png" ) avatar_url = yield self.handler.get_avatar_url(self.frank) @@ -150,6 +149,6 @@ class ProfileTestCase(unittest.TestCase): ) self.assertEquals( - (yield self.store.get_profile_avatar_url(self.frank.localpart)), + (yield self.store.get_profile_avatar_url(self.frank.to_string())), "http://my.server/pic.gif" ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24118bbc86..04603a9e9a 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -17,7 +17,6 @@ from tests import unittest from twisted.internet import defer -from synapse.storage.profile import ProfileStore from synapse.types import UserID from tests.utils import setup_test_homeserver @@ -29,36 +28,28 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver() - self.store = ProfileStore(hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") @defer.inlineCallbacks def test_displayname(self): - yield self.store.create_profile( - self.u_frank.localpart - ) - yield self.store.set_profile_displayname( - self.u_frank.localpart, "Frank" + self.u_frank.to_string(), "Frank" ) self.assertEquals( "Frank", - (yield self.store.get_profile_displayname(self.u_frank.localpart)) + (yield self.store.get_profile_displayname(self.u_frank.to_string())) ) @defer.inlineCallbacks def test_avatar_url(self): - yield self.store.create_profile( - self.u_frank.localpart - ) - yield self.store.set_profile_avatar_url( - self.u_frank.localpart, "http://my.site/here" + self.u_frank.to_string(), "http://my.site/here" ) self.assertEquals( "http://my.site/here", - (yield self.store.get_profile_avatar_url(self.u_frank.localpart)) + (yield self.store.get_profile_avatar_url(self.u_frank.to_string())) ) |