summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py60
1 files changed, 38 insertions, 22 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9e8643ae4d..b0ef7be155 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -855,13 +855,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Returns:
             Tuples of (auth_provider, external_id)
         """
-        res = await self.db_pool.simple_select_list(
-            table="user_external_ids",
-            keyvalues={"user_id": mxid},
-            retcols=("auth_provider", "external_id"),
-            desc="get_external_ids_by_user",
+        return cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_list(
+                table="user_external_ids",
+                keyvalues={"user_id": mxid},
+                retcols=("auth_provider", "external_id"),
+                desc="get_external_ids_by_user",
+            ),
         )
-        return [(r["auth_provider"], r["external_id"]) for r in res]
 
     async def count_all_users(self) -> int:
         """Counts all users registered on the homeserver."""
@@ -997,13 +999,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
 
     async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
-        results = await self.db_pool.simple_select_list(
-            "user_threepids",
-            keyvalues={"user_id": user_id},
-            retcols=["medium", "address", "validated_at", "added_at"],
-            desc="user_get_threepids",
+        results = cast(
+            List[Tuple[str, str, int, int]],
+            await self.db_pool.simple_select_list(
+                "user_threepids",
+                keyvalues={"user_id": user_id},
+                retcols=["medium", "address", "validated_at", "added_at"],
+                desc="user_get_threepids",
+            ),
         )
-        return [ThreepidResult(**r) for r in results]
+        return [
+            ThreepidResult(
+                medium=r[0],
+                address=r[1],
+                validated_at=r[2],
+                added_at=r[3],
+            )
+            for r in results
+        ]
 
     async def user_delete_threepid(
         self, user_id: str, medium: str, address: str
@@ -1042,7 +1055,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="add_user_bound_threepid",
         )
 
-    async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
+    async def user_get_bound_threepids(self, user_id: str) -> List[Tuple[str, str]]:
         """Get the threepids that a user has bound to an identity server through the homeserver
         The homeserver remembers where binds to an identity server occurred. Using this
         method can retrieve those threepids.
@@ -1051,15 +1064,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             user_id: The ID of the user to retrieve threepids for
 
         Returns:
-            List of dictionaries containing the following keys:
-                medium (str): The medium of the threepid (e.g "email")
-                address (str): The address of the threepid (e.g "bob@example.com")
-        """
-        return await self.db_pool.simple_select_list(
-            table="user_threepid_id_server",
-            keyvalues={"user_id": user_id},
-            retcols=["medium", "address"],
-            desc="user_get_bound_threepids",
+            List of tuples of two strings:
+                medium: The medium of the threepid (e.g "email")
+                address: The address of the threepid (e.g "bob@example.com")
+        """
+        return cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_list(
+                table="user_threepid_id_server",
+                keyvalues={"user_id": user_id},
+                retcols=["medium", "address"],
+                desc="user_get_bound_threepids",
+            ),
         )
 
     async def remove_user_bound_threepid(