diff --git a/changelog.d/16468.misc b/changelog.d/16468.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16468.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index f1a7a05df6..6c2a49a3b9 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -212,8 +212,8 @@ class AccountValidityHandler:
addresses = []
for threepid in threepids:
- if threepid["medium"] == "email":
- addresses.append(threepid["address"])
+ if threepid.medium == "email":
+ addresses.append(threepid.address)
return addresses
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 97fd1fd427..2c2baeac67 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -16,6 +16,8 @@ import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
+import attr
+
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
@@ -93,7 +95,7 @@ class AdminHandler:
]
user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url
- user_info_dict["threepids"] = threepids
+ user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 67adeae6a7..6a8f8f2fd1 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -117,9 +117,9 @@ class DeactivateAccountHandler:
# Remove any local threepid associations for this account.
local_threepids = await self.store.user_get_threepids(user_id)
- for threepid in local_threepids:
+ for local_threepid in local_threepids:
await self._auth_handler.delete_local_threepid(
- user_id, threepid["medium"], threepid["address"]
+ user_id, local_threepid.medium, local_threepid.address
)
# delete any devices belonging to the user, which will also
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 65e2aca456..0786d20635 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -678,7 +678,7 @@ class ModuleApi:
"msisdn" for phone numbers, and an "address" key which value is the
threepid's address.
"""
- return await self._store.user_get_threepids(user_id)
+ return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)]
def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
"""Check if user exists.
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index cd995e8dbb..7fe16130e7 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -329,9 +329,8 @@ class UserRestServletV2(RestServlet):
if threepids is not None:
# get changed threepids (added and removed)
- # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
cur_threepids = {
- (threepid["medium"], threepid["address"])
+ (threepid.medium, threepid.address)
for threepid in await self.store.user_get_threepids(user_id)
}
add_threepids = new_threepids - cur_threepids
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index e74a87af4d..641390cb30 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -24,6 +24,8 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
from pydantic.v1 import StrictBool, StrictStr, constr
else:
from pydantic import StrictBool, StrictStr, constr
+
+import attr
from typing_extensions import Literal
from twisted.web.server import Request
@@ -595,7 +597,7 @@ class ThreepidRestServlet(RestServlet):
threepids = await self.datastore.user_get_threepids(requester.user.to_string())
- return 200, {"threepids": threepids}
+ return 200, {"threepids": [attr.asdict(t) for t in threepids]}
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
# the endpoint is deprecated. (If you really want to, you could do this by reusing
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 64a2c31a5d..9e8643ae4d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -143,6 +143,14 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider."""
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidResult:
+ medium: str
+ address: str
+ validated_at: int
+ added_at: int
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -988,13 +996,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
+ async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
+ results = await self.db_pool.simple_select_list(
"user_threepids",
- {"user_id": user_id},
- ["medium", "address", "validated_at", "added_at"],
- "user_get_threepids",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address", "validated_at", "added_at"],
+ desc="user_get_threepids",
)
+ return [ThreepidResult(**r) for r in results]
async def user_delete_threepid(
self, user_id: str, medium: str, address: str
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 172fc3a736..1dabf52156 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -94,12 +94,12 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(len(emails), 1)
email = emails[0]
- self.assertEqual(email["medium"], "email")
- self.assertEqual(email["address"], "bob@bobinator.bob")
+ self.assertEqual(email.medium, "email")
+ self.assertEqual(email.address, "bob@bobinator.bob")
# Should these be 0?
- self.assertEqual(email["validated_at"], 0)
- self.assertEqual(email["added_at"], 0)
+ self.assertEqual(email.validated_at, 0)
+ self.assertEqual(email.added_at, 0)
# Check that the displayname was assigned
displayname = self.get_success(
|