summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-10-31 15:00:19 +0000
committerErik Johnston <erik@matrix.org>2016-10-31 15:52:36 +0000
commitbabfa01cc7645337299ee4d4ec6fb377b48f89ab (patch)
treee6819df43606e82ed680029d6d63a27993f22e24
parentMigrate old profile data (diff)
downloadsynapse-babfa01cc7645337299ee4d4ec6fb377b48f89ab.tar.xz
Use new tables
-rw-r--r--synapse/handlers/message.py9
-rw-r--r--synapse/handlers/profile.py44
-rw-r--r--synapse/handlers/room_member.py10
-rw-r--r--synapse/push/mailer.py3
-rw-r--r--synapse/storage/profile.py80
-rw-r--r--tests/handlers/test_profile.py11
-rw-r--r--tests/storage/test_profile.py19
7 files changed, 87 insertions, 89 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index abfa8c65a4..5a20a847ee 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -202,8 +202,13 @@ class MessageHandler(BaseHandler):
                 content = builder.content
 
                 try:
-                    content["displayname"] = yield profile.get_displayname(target)
-                    content["avatar_url"] = yield profile.get_avatar_url(target)
+                    display_name = yield profile.get_displayname(target)
+                    if display_name:
+                        content["displayname"] = display_name
+
+                    avatar_url = yield profile.get_avatar_url(target)
+                    if avatar_url:
+                        content["avatar_url"] = avatar_url
                 except Exception as e:
                     logger.info(
                         "Failed to get profile information for %r: %s",
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c1f6d88fa2..d0adf4e934 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -39,11 +39,11 @@ class ProfileHandler(BaseHandler):
     @defer.inlineCallbacks
     def get_displayname(self, target_user):
         if self.hs.is_mine(target_user):
-            displayname = yield self.store.get_profile_displayname(
-                target_user.localpart
+            display_name = yield self.store.get_profile_displayname(
+                target_user.to_string(),
             )
 
-            defer.returnValue(displayname)
+            defer.returnValue(display_name)
         else:
             try:
                 result = yield self.federation.make_query(
@@ -78,19 +78,7 @@ class ProfileHandler(BaseHandler):
             new_displayname = None
 
         yield self.store.set_profile_displayname(
-            target_user.localpart, new_displayname
-        )
-
-        if new_displayname:
-            content = {"rows": [{
-                "display_name": new_displayname
-            }]}
-        else:
-            # TODO: Delete in this case
-            content = {}
-
-        yield self.store.update_profile_key(
-            target_user.to_string(), "default", "m.display_name", content
+            target_user.to_string(), new_displayname
         )
 
         yield self._update_join_states(requester)
@@ -99,7 +87,7 @@ class ProfileHandler(BaseHandler):
     def get_avatar_url(self, target_user):
         if self.hs.is_mine(target_user):
             avatar_url = yield self.store.get_profile_avatar_url(
-                target_user.localpart
+                target_user.to_string(),
             )
 
             defer.returnValue(avatar_url)
@@ -133,19 +121,7 @@ class ProfileHandler(BaseHandler):
             raise AuthError(400, "Cannot set another user's avatar_url")
 
         yield self.store.set_profile_avatar_url(
-            target_user.localpart, new_avatar_url
-        )
-
-        if new_avatar_url:
-            content = {"rows": [{
-                "url": new_avatar_url
-            }]}
-        else:
-            # TODO: Delete in this case
-            content = {}
-
-        yield self.store.update_profile_key(
-            target_user.to_string(), "default", "m.avatar_url", content
+            target_user.to_string(), new_avatar_url
         )
 
         yield self._update_join_states(requester)
@@ -161,13 +137,13 @@ class ProfileHandler(BaseHandler):
         response = {}
 
         if just_field is None or just_field == "displayname":
-            response["displayname"] = yield self.store.get_profile_displayname(
-                user.localpart
+            response["displayname"] = yield self.get_displayname(
+                user
             )
 
         if just_field is None or just_field == "avatar_url":
-            response["avatar_url"] = yield self.store.get_profile_avatar_url(
-                user.localpart
+            response["avatar_url"] = yield self.get_avatar_url(
+                user
             )
 
         defer.returnValue(response)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ba49075a20..73dcf326a6 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -245,8 +245,14 @@ class RoomMemberHandler(BaseHandler):
                 content["membership"] = Membership.JOIN
 
                 profile = self.hs.get_handlers().profile_handler
-                content["displayname"] = yield profile.get_displayname(target)
-                content["avatar_url"] = yield profile.get_avatar_url(target)
+
+                display_name = yield profile.get_displayname(target)
+                if display_name:
+                    content["displayname"] = display_name
+
+                avatar_url = yield profile.get_avatar_url(target)
+                if avatar_url:
+                    content["avatar_url"] = avatar_url
 
                 if requester.is_guest:
                     content["kind"] = "guest"
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 53551632b6..d72f98dee7 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -25,7 +25,6 @@ from synapse.util.async import concurrently_execute
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event, descriptor_from_member_events
 )
-from synapse.types import UserID
 from synapse.api.errors import StoreError
 from synapse.api.constants import EventTypes
 from synapse.visibility import filter_events_for_client
@@ -130,7 +129,7 @@ class Mailer(object):
 
         try:
             user_display_name = yield self.store.get_profile_displayname(
-                UserID.from_string(user_id).localpart
+                user_id
             )
             if user_display_name is None:
                 user_display_name = user_id
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index f0e281a483..94415c9ead 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -22,42 +22,64 @@ import ujson
 
 class ProfileStore(SQLBaseStore):
     def create_profile(self, user_localpart):
-        return self._simple_insert(
-            table="profiles",
-            values={"user_id": user_localpart},
-            desc="create_profile",
-        )
+        return defer.succeed(None)
 
-    def get_profile_displayname(self, user_localpart):
-        return self._simple_select_one_onecol(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            retcol="displayname",
-            desc="get_profile_displayname",
+    @defer.inlineCallbacks
+    def get_profile_displayname(self, user_id):
+        profile = yield self.get_profile_key(
+            user_id, "default", "m.display_name"
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"displayname": new_displayname},
-            desc="set_profile_displayname",
+        if profile:
+            try:
+                display_name = profile["rows"][0]["display_name"]
+            except (KeyError, IndexError):
+                display_name = None
+        else:
+            display_name = None
+
+        defer.returnValue(display_name)
+
+    def set_profile_displayname(self, user_id, new_displayname):
+        if new_displayname:
+            content = {"rows": [{
+                "display_name": new_displayname
+            }]}
+        else:
+            # TODO: Delete in this case
+            content = {}
+
+        return self.update_profile_key(
+            user_id, "default", "m.display_name", content
         )
 
-    def get_profile_avatar_url(self, user_localpart):
-        return self._simple_select_one_onecol(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            retcol="avatar_url",
-            desc="get_profile_avatar_url",
+    @defer.inlineCallbacks
+    def get_profile_avatar_url(self, user_id):
+        profile = yield self.get_profile_key(
+            user_id, "default", "m.avatar_url"
         )
 
-    def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"avatar_url": new_avatar_url},
-            desc="set_profile_avatar_url",
+        if profile:
+            try:
+                avatar_url = profile["rows"][0]["avatar_url"]
+            except (KeyError, IndexError):
+                avatar_url = None
+        else:
+            avatar_url = None
+
+        defer.returnValue(avatar_url)
+
+    def set_profile_avatar_url(self, user_id, new_avatar_url):
+        if new_avatar_url:
+            content = {"rows": [{
+                "avatar_url": new_avatar_url
+            }]}
+        else:
+            # TODO: Delete in this case
+            content = {}
+
+        return self.update_profile_key(
+            user_id, "default", "m.avatar_url", content
         )
 
     @defer.inlineCallbacks
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f1f664275f..3fd3ed323f 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -76,7 +76,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_my_name(self):
         yield self.store.set_profile_displayname(
-            self.frank.localpart, "Frank"
+            self.frank.to_string(), "Frank"
         )
 
         displayname = yield self.handler.get_displayname(self.frank)
@@ -92,7 +92,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)),
+            (yield self.store.get_profile_displayname(self.frank.to_string())),
             "Frank Jr."
         )
 
@@ -123,8 +123,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
-        yield self.store.set_profile_displayname("caroline", "Caroline")
+        yield self.store.set_profile_displayname("@caroline:test", "Caroline")
 
         response = yield self.query_handlers["profile"](
             {"user_id": "@caroline:test", "field": "displayname"}
@@ -135,7 +134,7 @@ class ProfileTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
         yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+            self.frank.to_string(), "http://my.server/me.png"
         )
 
         avatar_url = yield self.handler.get_avatar_url(self.frank)
@@ -150,6 +149,6 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (yield self.store.get_profile_avatar_url(self.frank.to_string())),
             "http://my.server/pic.gif"
         )
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24118bbc86..04603a9e9a 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -17,7 +17,6 @@
 from tests import unittest
 from twisted.internet import defer
 
-from synapse.storage.profile import ProfileStore
 from synapse.types import UserID
 
 from tests.utils import setup_test_homeserver
@@ -29,36 +28,28 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver()
 
-        self.store = ProfileStore(hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(
-            self.u_frank.localpart
-        )
-
         yield self.store.set_profile_displayname(
-            self.u_frank.localpart, "Frank"
+            self.u_frank.to_string(), "Frank"
         )
 
         self.assertEquals(
             "Frank",
-            (yield self.store.get_profile_displayname(self.u_frank.localpart))
+            (yield self.store.get_profile_displayname(self.u_frank.to_string()))
         )
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(
-            self.u_frank.localpart
-        )
-
         yield self.store.set_profile_avatar_url(
-            self.u_frank.localpart, "http://my.site/here"
+            self.u_frank.to_string(), "http://my.site/here"
         )
 
         self.assertEquals(
             "http://my.site/here",
-            (yield self.store.get_profile_avatar_url(self.u_frank.localpart))
+            (yield self.store.get_profile_avatar_url(self.u_frank.to_string()))
         )