summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <seanq@matrix.org>2023-04-15 02:02:01 +0100
committerSean Quah <seanq@matrix.org>2023-04-15 02:52:42 +0100
commite6c582095f33542a89e2b42a6b8506716ff64615 (patch)
tree6f9f9dbd7c80d200ea226c4e8b056c050c8fe603
parentDe-localpart `ProfileWorkerStore.get_profileinfo()` (diff)
downloadsynapse-e6c582095f33542a89e2b42a6b8506716ff64615.tar.xz
De-localpart `ProfileWorkerStore.get_profile_displayname()`
Signed-off-by: Sean Quah <seanq@matrix.org>
-rw-r--r--synapse/handlers/account_validity.py5
-rw-r--r--synapse/handlers/profile.py7
-rw-r--r--synapse/push/mailer.py6
-rw-r--r--synapse/storage/databases/main/profile.py27
-rw-r--r--tests/handlers/test_profile.py8
-rw-r--r--tests/module_api/test_api.py2
-rw-r--r--tests/storage/test_profile.py6
7 files changed, 36 insertions, 25 deletions
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 4aa4ebf7e4..e0efc93f2e 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.api.errors import AuthError, StoreError, SynapseError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.types import UserID
 from synapse.util import stringutils
 from synapse.util.async_helpers import delay_cancellation
 
@@ -163,9 +162,7 @@ class AccountValidityHandler:
             return
 
         try:
-            user_display_name = await self.store.get_profile_displayname(
-                UserID.from_string(user_id).localpart
-            )
+            user_display_name = await self.store.get_profile_displayname(user_id)
             if user_display_name is None:
                 user_display_name = user_id
         except StoreError:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4fa5a8611f..1c5bdb15f1 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -100,7 +100,7 @@ class ProfileHandler:
         if self.hs.is_mine(target_user):
             try:
                 displayname = await self.store.get_profile_displayname(
-                    target_user.localpart
+                    target_user.to_string()
                 )
             except StoreError as e:
                 if e.code == 404:
@@ -364,7 +364,8 @@ class ProfileHandler:
                 Codes.FORBIDDEN,
             )
 
-        user = UserID.from_string(args["user_id"])
+        user_id = args["user_id"]
+        user = UserID.from_string(user_id)
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
@@ -374,7 +375,7 @@ class ProfileHandler:
         try:
             if just_field is None or just_field == "displayname":
                 response["displayname"] = await self.store.get_profile_displayname(
-                    user.localpart
+                    user_id
                 )
 
             if just_field is None or just_field == "avatar_url":
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 491a09b71d..bf9cd4109c 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -37,7 +37,7 @@ from synapse.push.push_types import (
     TemplateVars,
 )
 from synapse.storage.databases.main.event_push_actions import EmailPushAction
-from synapse.types import StateMap, UserID
+from synapse.types import StateMap
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import concurrently_execute
 from synapse.visibility import filter_events_for_client
@@ -246,9 +246,7 @@ class Mailer:
         state_by_room = {}
 
         try:
-            user_display_name = await self.store.get_profile_displayname(
-                UserID.from_string(user_id).localpart
-            )
+            user_display_name = await self.store.get_profile_displayname(user_id)
             if user_display_name is None:
                 user_display_name = user_id
         except StoreError:
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 23021a1f1f..a5e2ea9f04 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -57,13 +57,26 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
-        return await self.db_pool.simple_select_one_onecol(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            retcol="displayname",
-            desc="get_profile_displayname",
-        )
+    async def get_profile_displayname(self, user_id: str) -> Optional[str]:
+        try:
+            return await self.db_pool.simple_select_one_onecol(
+                table="profiles",
+                keyvalues={"full_user_id": user_id},
+                retcol="displayname",
+                desc="get_profile_displayname",
+            )
+        except StoreError as e:
+            if e.code == 404:
+                # Fall back to the `user_id` column.
+                user_localpart = UserID.from_string(user_id).localpart
+                return await self.db_pool.simple_select_one_onecol(
+                    table="profiles",
+                    keyvalues={"user_id": user_localpart},
+                    retcol="displayname",
+                    desc="get_profile_displayname",
+                )
+            else:
+                raise
 
     async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 7c174782da..d8b2797859 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -84,7 +84,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             (
                 self.get_success(
-                    self.store.get_profile_displayname(self.frank.localpart)
+                    self.store.get_profile_displayname(self.frank.to_string())
                 )
             ),
             "Frank Jr.",
@@ -100,7 +100,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             (
                 self.get_success(
-                    self.store.get_profile_displayname(self.frank.localpart)
+                    self.store.get_profile_displayname(self.frank.to_string())
                 )
             ),
             "Frank",
@@ -114,7 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            self.get_success(self.store.get_profile_displayname(self.frank.localpart))
+            self.get_success(self.store.get_profile_displayname(self.frank.to_string()))
         )
 
     def test_set_my_name_if_disabled(self) -> None:
@@ -128,7 +128,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             (
                 self.get_success(
-                    self.store.get_profile_displayname(self.frank.localpart)
+                    self.store.get_profile_displayname(self.frank.to_string())
                 )
             ),
             "Frank",
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 758b4bc38b..23364b8f7e 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -103,7 +103,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         self.assertEqual(email["added_at"], 0)
 
         # Check that the displayname was assigned
-        displayname = self.get_success(self.store.get_profile_displayname("bob"))
+        displayname = self.get_success(self.store.get_profile_displayname("@bob:test"))
         self.assertEqual(displayname, "Bobberino")
 
     def test_can_register_admin_user(self) -> None:
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a019d06e09..702430f513 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -37,7 +37,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
             "Frank",
             (
                 self.get_success(
-                    self.store.get_profile_displayname(self.u_frank.localpart)
+                    self.store.get_profile_displayname(self.u_frank.to_string())
                 )
             ),
         )
@@ -48,7 +48,9 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertIsNone(
-            self.get_success(self.store.get_profile_displayname(self.u_frank.localpart))
+            self.get_success(
+                self.store.get_profile_displayname(self.u_frank.to_string())
+            )
         )
 
     def test_avatar_url(self) -> None: