summary refs log tree commit diff
diff options
context:
space:
mode:
authorJason Robinson <jasonr@element.io>2021-08-04 13:40:25 +0300
committerGitHub <noreply@github.com>2021-08-04 10:40:25 +0000
commitc2000ab35b76288a625f598d2382d4e3f29f65f6 (patch)
treeb0f5bdd7e9fe5fb2002d0cf166ce73b2c16d336f
parentAdd warnings to ip_range_blacklist usage with proxies (#10129) (diff)
downloadsynapse-c2000ab35b76288a625f598d2382d4e3f29f65f6.tar.xz
Add `get_userinfo_by_id` method to `ModuleApi` (#9581)
Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions.

Signed-off-by: Jason Robinson <jasonr@matrix.org>
-rw-r--r--changelog.d/9581.feature1
-rw-r--r--synapse/module_api/__init__.py12
-rw-r--r--synapse/storage/databases/main/registration.py30
-rw-r--r--synapse/types.py29
-rw-r--r--tests/module_api/test_api.py10
5 files changed, 80 insertions, 2 deletions
diff --git a/changelog.d/9581.feature b/changelog.d/9581.feature
new file mode 100644
index 0000000000..fa1949cd4b
--- /dev/null
+++ b/changelog.d/9581.feature
@@ -0,0 +1 @@
+Add `get_userinfo_by_id` method to ModuleApi.
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 473812b8e2..1cc13fc97b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester
 from synapse.util import Clock
 from synapse.util.caches.descriptors import cached
 
@@ -174,6 +174,16 @@ class ModuleApi:
         """The application name configured in the homeserver's configuration."""
         return self._hs.config.email.email_app_name
 
+    async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
+        """Get user info by user_id
+
+        Args:
+            user_id: Fully qualified user id.
+        Returns:
+            UserInfo object if a user was found, otherwise None
+        """
+        return await self._store.get_userinfo_by_id(user_id)
+
     async def get_user_by_req(
         self,
         req: SynapseRequest,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6ad1a0cf7f..14670c2881 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID
+from synapse.types import UserID, UserInfo
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+        """Deprecated: use get_userinfo_by_id instead"""
         return await self.db_pool.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
@@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_id",
         )
 
+    async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
+        """Get a UserInfo object for a user by user ID.
+
+        Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
+        this method should be cached.
+
+        Args:
+             user_id: The user to fetch user info for.
+        Returns:
+            `UserInfo` object if user found, otherwise `None`.
+        """
+        user_data = await self.get_user_by_id(user_id)
+        if not user_data:
+            return None
+        return UserInfo(
+            appservice_id=user_data["appservice_id"],
+            consent_server_notice_sent=user_data["consent_server_notice_sent"],
+            consent_version=user_data["consent_version"],
+            creation_ts=user_data["creation_ts"],
+            is_admin=bool(user_data["admin"]),
+            is_deactivated=bool(user_data["deactivated"]),
+            is_guest=bool(user_data["is_guest"]),
+            is_shadow_banned=bool(user_data["shadow_banned"]),
+            user_id=UserID.from_string(user_data["name"]),
+            user_type=user_data["user_type"],
+        )
+
     async def is_trial_user(self, user_id: str) -> bool:
         """Checks if user is in the "trial" period, i.e. within the first
         N days of registration defined by `mau_trial_days` config
diff --git a/synapse/types.py b/synapse/types.py
index 429bb013d2..80fa903c4b 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info):
     # and return that one key
     for key_id, key_data in keys.items():
         return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
+
+
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class UserInfo:
+    """Holds information about a user. Result of get_userinfo_by_id.
+
+    Attributes:
+        user_id:  ID of the user.
+        appservice_id:  Application service ID that created this user.
+        consent_server_notice_sent:  Version of policy documents the user has been sent.
+        consent_version:  Version of policy documents the user has consented to.
+        creation_ts:  Creation timestamp of the user.
+        is_admin:  True if the user is an admin.
+        is_deactivated:  True if the user has been deactivated.
+        is_guest:  True if the user is a guest user.
+        is_shadow_banned:  True if the user has been shadow-banned.
+        user_type:  User type (None for normal user, 'support' and 'bot' other options).
+    """
+
+    user_id: UserID
+    appservice_id: Optional[int]
+    consent_server_notice_sent: Optional[str]
+    consent_version: Optional[str]
+    user_type: Optional[str]
+    creation_ts: int
+    is_admin: bool
+    is_deactivated: bool
+    is_guest: bool
+    is_shadow_banned: bool
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 81d9e2f484..0b817cc701 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase):
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
 
+    def test_get_userinfo_by_id(self):
+        user_id = self.register_user("alice", "1234")
+        found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
+        self.assertEqual(found_user.user_id.to_string(), user_id)
+        self.assertIdentical(found_user.is_admin, False)
+
+    def test_get_userinfo_by_id__no_user_found(self):
+        found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test"))
+        self.assertIsNone(found_user)
+
     def test_sending_events_into_room(self):
         """Tests that a module can send events into a room"""
         # Mock out create_and_send_nonmember_event to check whether events are being sent