summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <seanq@matrix.org>2023-04-15 02:15:02 +0100
committerSean Quah <seanq@matrix.org>2023-04-15 02:52:42 +0100
commit96bb319d14f8c16a1f7b712ccc672dfbfe51f59c (patch)
treec8f96a7690dc1bb079ebb13de6419ff3f160b53b
parentDe-localpart `ProfileWorkerStore.get_profile_displayname()` (diff)
downloadsynapse-96bb319d14f8c16a1f7b712ccc672dfbfe51f59c.tar.xz
De-localpart `ProfileWorkerStore.get_profile_avatar_url()`
Signed-off-by: Sean Quah <seanq@matrix.org>
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/storage/databases/main/profile.py27
-rw-r--r--tests/handlers/test_profile.py24
-rw-r--r--tests/storage/test_profile.py6
4 files changed, 46 insertions, 15 deletions
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 1c5bdb15f1..aa90e38f5c 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -197,7 +197,7 @@ class ProfileHandler:
         if self.hs.is_mine(target_user):
             try:
                 avatar_url = await self.store.get_profile_avatar_url(
-                    target_user.localpart
+                    target_user.to_string()
                 )
             except StoreError as e:
                 if e.code == 404:
@@ -380,7 +380,7 @@ class ProfileHandler:
 
             if just_field is None or just_field == "avatar_url":
                 response["avatar_url"] = await self.store.get_profile_avatar_url(
-                    user.localpart
+                    user_id
                 )
         except StoreError as e:
             if e.code == 404:
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index a5e2ea9f04..12f984d433 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -78,13 +78,26 @@ class ProfileWorkerStore(SQLBaseStore):
             else:
                 raise
 
-    async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
-        return await self.db_pool.simple_select_one_onecol(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            retcol="avatar_url",
-            desc="get_profile_avatar_url",
-        )
+    async def get_profile_avatar_url(self, user_id: str) -> Optional[str]:
+        try:
+            return await self.db_pool.simple_select_one_onecol(
+                table="profiles",
+                keyvalues={"full_user_id": user_id},
+                retcol="avatar_url",
+                desc="get_profile_avatar_url",
+            )
+        except StoreError as e:
+            if e.code == 404:
+                # Fall back to the `user_id` column.
+                user_localpart = UserID.from_string(user_id).localpart
+                return await self.db_pool.simple_select_one_onecol(
+                    table="profiles",
+                    keyvalues={"user_id": user_localpart},
+                    retcol="avatar_url",
+                    desc="get_profile_avatar_url",
+                )
+            else:
+                raise
 
     async def create_profile(self, user_localpart: str) -> None:
         await self.db_pool.simple_insert(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index d8b2797859..2cf3fd2119 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -201,7 +201,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (
+                self.get_success(
+                    self.store.get_profile_avatar_url(self.frank.to_string())
+                )
+            ),
             "http://my.server/pic.gif",
         )
 
@@ -215,7 +219,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (
+                self.get_success(
+                    self.store.get_profile_avatar_url(self.frank.to_string())
+                )
+            ),
             "http://my.server/me.png",
         )
 
@@ -229,7 +237,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (
+                self.get_success(
+                    self.store.get_profile_avatar_url(self.frank.to_string())
+                )
+            ),
         )
 
     def test_set_my_avatar_if_disabled(self) -> None:
@@ -243,7 +255,11 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(
-            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
+            (
+                self.get_success(
+                    self.store.get_profile_avatar_url(self.frank.to_string())
+                )
+            ),
             "http://my.server/me.png",
         )
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 702430f513..136352f838 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -66,7 +66,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
             "http://my.site/here",
             (
                 self.get_success(
-                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                    self.store.get_profile_avatar_url(self.u_frank.to_string())
                 )
             ),
         )
@@ -77,5 +77,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart))
+            self.get_success(
+                self.store.get_profile_avatar_url(self.u_frank.to_string())
+            )
         )