diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6ad1a0cf7f..14670c2881 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID
+from synapse.types import UserID, UserInfo
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ """Deprecated: use get_userinfo_by_id instead"""
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
@@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id",
)
+ async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
+ """Get a UserInfo object for a user by user ID.
+
+ Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
+ this method should be cached.
+
+ Args:
+ user_id: The user to fetch user info for.
+ Returns:
+ `UserInfo` object if user found, otherwise `None`.
+ """
+ user_data = await self.get_user_by_id(user_id)
+ if not user_data:
+ return None
+ return UserInfo(
+ appservice_id=user_data["appservice_id"],
+ consent_server_notice_sent=user_data["consent_server_notice_sent"],
+ consent_version=user_data["consent_version"],
+ creation_ts=user_data["creation_ts"],
+ is_admin=bool(user_data["admin"]),
+ is_deactivated=bool(user_data["deactivated"]),
+ is_guest=bool(user_data["is_guest"]),
+ is_shadow_banned=bool(user_data["shadow_banned"]),
+ user_id=UserID.from_string(user_data["name"]),
+ user_type=user_data["user_type"],
+ )
+
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
|