summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2023-09-14 12:46:30 +0100
committerGitHub <noreply@github.com>2023-09-14 12:46:30 +0100
commit954921736b88de25c775c519a206449e46b3bf07 (patch)
treed4e428b26c39659d9da44996d9d3db3bcf6f2ddd
parentRemove a reference cycle in background process (#16314) (diff)
downloadsynapse-954921736b88de25c775c519a206449e46b3bf07.tar.xz
Refactor `get_user_by_id` (#16316)
-rw-r--r--changelog.d/16316.misc1
-rw-r--r--synapse/api/auth/internal.py2
-rw-r--r--synapse/api/auth/msc3861_delegated.py2
-rw-r--r--synapse/handlers/account.py2
-rw-r--r--synapse/handlers/admin.py49
-rw-r--r--synapse/handlers/message.py6
-rw-r--r--synapse/module_api/__init__.py4
-rw-r--r--synapse/rest/consent/consent_resource.py2
-rw-r--r--synapse/server_notices/consent_server_notices.py6
-rw-r--r--synapse/storage/databases/main/client_ips.py11
-rw-r--r--synapse/storage/databases/main/registration.py76
-rw-r--r--synapse/types/__init__.py10
-rw-r--r--tests/api/test_auth.py12
-rw-r--r--tests/storage/test_registration.py48
14 files changed, 108 insertions, 123 deletions
diff --git a/changelog.d/16316.misc b/changelog.d/16316.misc
new file mode 100644
index 0000000000..aa0644f278
--- /dev/null
+++ b/changelog.d/16316.misc
@@ -0,0 +1 @@
+Refactor `get_user_by_id`.
diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index 6a5fd44ec0..a75f6f2cc4 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
             stored_user = await self.store.get_user_by_id(user_id)
             if not stored_user:
                 raise InvalidClientTokenError("Unknown user_id %s" % user_id)
-            if not stored_user["is_guest"]:
+            if not stored_user.is_guest:
                 raise InvalidClientTokenError(
                     "Guest access token used for regular user"
                 )
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index ef5d3f9b81..31bb035cc8 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
             user_id = UserID(username, self._hostname)
 
             # First try to find a user from the username claim
-            user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
+            user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
             if user_info is None:
                 # If the user does not exist, we should create it on the fly
                 # TODO: we could use SCIM to provision users ahead of time and listen
diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index c05a14304c..fa043cca86 100644
--- a/synapse/handlers/account.py
+++ b/synapse/handlers/account.py
@@ -102,7 +102,7 @@ class AccountHandler:
         """
         status = {"exists": False}
 
-        userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
+        userinfo = await self._main_store.get_user_by_id(user_id.to_string())
 
         if userinfo is not None:
             status = {
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2f0e5f3b0a..7092ff3449 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
 
 from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
-from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -57,38 +57,30 @@ class AdminHandler:
 
     async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
-        user_info_dict = await self._store.get_user_by_id(user.to_string())
-        if user_info_dict is None:
+        user_info: Optional[UserInfo] = await self._store.get_user_by_id(
+            user.to_string()
+        )
+        if user_info is None:
             return None
 
-        # Restrict returned information to a known set of fields. This prevents additional
-        # fields added to get_user_by_id from modifying Synapse's external API surface.
-        user_info_to_return = {
-            "name",
-            "admin",
-            "deactivated",
-            "locked",
-            "shadow_banned",
-            "creation_ts",
-            "appservice_id",
-            "consent_server_notice_sent",
-            "consent_version",
-            "consent_ts",
-            "user_type",
-            "is_guest",
-            "last_seen_ts",
+        user_info_dict = {
+            "name": user.to_string(),
+            "admin": user_info.is_admin,
+            "deactivated": user_info.is_deactivated,
+            "locked": user_info.locked,
+            "shadow_banned": user_info.is_shadow_banned,
+            "creation_ts": user_info.creation_ts,
+            "appservice_id": user_info.appservice_id,
+            "consent_server_notice_sent": user_info.consent_server_notice_sent,
+            "consent_version": user_info.consent_version,
+            "consent_ts": user_info.consent_ts,
+            "user_type": user_info.user_type,
+            "is_guest": user_info.is_guest,
         }
 
         if self._msc3866_enabled:
             # Only include the approved flag if support for MSC3866 is enabled.
-            user_info_to_return.add("approved")
-
-        # Restrict returned keys to a known set.
-        user_info_dict = {
-            key: value
-            for key, value in user_info_dict.items()
-            if key in user_info_to_return
-        }
+            user_info_dict["approved"] = user_info.approved
 
         # Add additional user metadata
         profile = await self._store.get_profileinfo(user)
@@ -105,6 +97,9 @@ class AdminHandler:
         user_info_dict["external_ids"] = external_ids
         user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
 
+        last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
+        user_info_dict["last_seen_ts"] = last_seen_ts
+
         return user_info_dict
 
     async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d6be18cdef..c036578a3d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -828,13 +828,13 @@ class EventCreationHandler:
 
         u = await self.store.get_user_by_id(user_id)
         assert u is not None
-        if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
+        if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
             # support and bot users are not required to consent
             return
-        if u["appservice_id"] is not None:
+        if u.appservice_id is not None:
             # users registered by an appservice are exempt
             return
-        if u["consent_version"] == self.config.consent.user_consent_version:
+        if u.consent_version == self.config.consent.user_consent_version:
             return
 
         consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d6efe10a28..7ec202be23 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -572,7 +572,7 @@ class ModuleApi:
         Returns:
             UserInfo object if a user was found, otherwise None
         """
-        return await self._store.get_userinfo_by_id(user_id)
+        return await self._store.get_user_by_id(user_id)
 
     async def get_user_by_req(
         self,
@@ -1878,7 +1878,7 @@ class AccountDataManager:
             raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
 
         # Ensure the user exists, so we don't just write to users that aren't there.
-        if await self._store.get_userinfo_by_id(user_id) is None:
+        if await self._store.get_user_by_id(user_id) is None:
             raise ValueError(f"User {user_id} does not exist on this server.")
 
         await self._handler.add_account_data_for_user(user_id, data_type, new_data)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 25f9ea285b..88d3ec1baf 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
             if u is None:
                 raise NotFoundError("Unknown user")
 
-            has_consented = u["consent_version"] == version
+            has_consented = u.consent_version == version
             userhmac = userhmac_bytes.decode("ascii")
 
         try:
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 94025ba41f..a879b6505e 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -79,15 +79,15 @@ class ConsentServerNotices:
             if u is None:
                 return
 
-            if u["is_guest"] and not self._send_to_guests:
+            if u.is_guest and not self._send_to_guests:
                 # don't send to guests
                 return
 
-            if u["consent_version"] == self._current_consent_version:
+            if u.consent_version == self._current_consent_version:
                 # user has already consented
                 return
 
-            if u["consent_server_notice_sent"] == self._current_consent_version:
+            if u.consent_server_notice_sent == self._current_consent_version:
                 # we've already sent a notice to the user
                 return
 
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index d8d333e11d..7da47c3dd7 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
                     }
 
         return list(results.values())
+
+    async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
+        """Get the last seen timestamp for a user, if we have it."""
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="user_ips",
+            keyvalues={"user_id": user_id},
+            retcol="MAX(last_seen)",
+            allow_none=True,
+            desc="get_last_seen_for_user_id",
+        )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e34156dc55..cc964604e2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
 
     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
-        """Deprecated: use get_userinfo_by_id instead"""
+    async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
+        """Returns info about the user account, if it exists."""
 
         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # We could technically use simple_select_one here, but it would not perform
@@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             txn.execute(
                 """
                 SELECT
-                    name, password_hash, is_guest, admin, consent_version, consent_ts,
+                    name, is_guest, admin, consent_version, consent_ts,
                     consent_server_notice_sent, appservice_id, creation_ts, user_type,
                     deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
                     COALESCE(approved, TRUE) AS approved,
-                    COALESCE(locked, FALSE) AS locked, last_seen_ts
+                    COALESCE(locked, FALSE) AS locked
                 FROM users
-                LEFT JOIN (
-                    SELECT user_id, MAX(last_seen) AS last_seen_ts
-                    FROM user_ips GROUP BY user_id
-                ) ls ON users.name = ls.user_id
                 WHERE name = ?
                 """,
                 (user_id,),
@@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_id",
             func=get_user_by_id_txn,
         )
-
-        if row is not None:
-            # If we're using SQLite our boolean values will be integers. Because we
-            # present some of this data as is to e.g. server admins via REST APIs, we
-            # want to make sure we're returning the right type of data.
-            # Note: when adding a column name to this list, be wary of NULLable columns,
-            # since NULL values will be turned into False.
-            boolean_columns = [
-                "admin",
-                "deactivated",
-                "shadow_banned",
-                "approved",
-                "locked",
-            ]
-            for column in boolean_columns:
-                row[column] = bool(row[column])
-
-        return row
-
-    async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
-        """Get a UserInfo object for a user by user ID.
-
-        Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
-        this method should be cached.
-
-        Args:
-             user_id: The user to fetch user info for.
-        Returns:
-            `UserInfo` object if user found, otherwise `None`.
-        """
-        user_data = await self.get_user_by_id(user_id)
-        if not user_data:
+        if row is None:
             return None
+
         return UserInfo(
-            appservice_id=user_data["appservice_id"],
-            consent_server_notice_sent=user_data["consent_server_notice_sent"],
-            consent_version=user_data["consent_version"],
-            creation_ts=user_data["creation_ts"],
-            is_admin=bool(user_data["admin"]),
-            is_deactivated=bool(user_data["deactivated"]),
-            is_guest=bool(user_data["is_guest"]),
-            is_shadow_banned=bool(user_data["shadow_banned"]),
-            user_id=UserID.from_string(user_data["name"]),
-            user_type=user_data["user_type"],
-            last_seen_ts=user_data["last_seen_ts"],
+            appservice_id=row["appservice_id"],
+            consent_server_notice_sent=row["consent_server_notice_sent"],
+            consent_version=row["consent_version"],
+            consent_ts=row["consent_ts"],
+            creation_ts=row["creation_ts"],
+            is_admin=bool(row["admin"]),
+            is_deactivated=bool(row["deactivated"]),
+            is_guest=bool(row["is_guest"]),
+            is_shadow_banned=bool(row["shadow_banned"]),
+            user_id=UserID.from_string(row["name"]),
+            user_type=row["user_type"],
+            approved=bool(row["approved"]),
+            locked=bool(row["locked"]),
         )
 
     async def is_trial_user(self, user_id: str) -> bool:
@@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         now = self._clock.time_msec()
         days = self.config.server.mau_appservice_trial_days.get(
-            info["appservice_id"], self.config.server.mau_trial_days
+            info.appservice_id, self.config.server.mau_trial_days
         )
         trial_duration_ms = days * 24 * 60 * 60 * 1000
-        is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
+        is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
         return is_trial
 
     @cached()
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 488714f60c..76b0e3e694 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)
 class UserInfo:
-    """Holds information about a user. Result of get_userinfo_by_id.
+    """Holds information about a user. Result of get_user_by_id.
 
     Attributes:
         user_id:  ID of the user.
         appservice_id:  Application service ID that created this user.
         consent_server_notice_sent:  Version of policy documents the user has been sent.
         consent_version:  Version of policy documents the user has consented to.
+        consent_ts: Time the user consented
         creation_ts:  Creation timestamp of the user.
         is_admin:  True if the user is an admin.
         is_deactivated:  True if the user has been deactivated.
         is_guest:  True if the user is a guest user.
         is_shadow_banned:  True if the user has been shadow-banned.
         user_type:  User type (None for normal user, 'support' and 'bot' other options).
-        last_seen_ts:  Last activity timestamp of the user.
+        approved: If the user has been "approved" to register on the server.
+        locked: Whether the user's account has been locked
     """
 
     user_id: UserID
     appservice_id: Optional[int]
     consent_server_notice_sent: Optional[str]
     consent_version: Optional[str]
+    consent_ts: Optional[int]
     user_type: Optional[str]
     creation_ts: int
     is_admin: bool
     is_deactivated: bool
     is_guest: bool
     is_shadow_banned: bool
-    last_seen_ts: Optional[int]
+    approved: bool
+    locked: bool
 
 
 class UserProfile(TypedDict):
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dcd01d5688..e00d7215df 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        # This just needs to return a truth-y value.
-        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+
+        class FakeUserInfo:
+            is_guest = False
+
+        self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
         self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
@@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_guest_user_from_macaroon(self) -> None:
-        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
+        class FakeUserInfo:
+            is_guest = True
+
+        self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
         self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 95c9792d54..0cca34d355 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, UserInfo
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, override_config
@@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
         self.get_success(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEqual(
-            {
+            UserInfo(
                 # TODO(paul): Surely this field should be 'user_id', not 'name'
-                "name": self.user_id,
-                "password_hash": self.pwhash,
-                "admin": 0,
-                "is_guest": 0,
-                "consent_version": None,
-                "consent_ts": None,
-                "consent_server_notice_sent": None,
-                "appservice_id": None,
-                "creation_ts": 0,
-                "user_type": None,
-                "deactivated": 0,
-                "locked": 0,
-                "shadow_banned": 0,
-                "approved": 1,
-                "last_seen_ts": None,
-            },
+                user_id=UserID.from_string(self.user_id),
+                is_admin=False,
+                is_guest=False,
+                consent_server_notice_sent=None,
+                consent_ts=None,
+                consent_version=None,
+                appservice_id=None,
+                creation_ts=0,
+                user_type=None,
+                is_deactivated=False,
+                locked=False,
+                is_shadow_banned=False,
+                approved=True,
+            ),
             (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
 
@@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
 
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user
-        self.assertEqual(user["consent_version"], "1")
-        self.assertGreater(user["consent_ts"], before_consent)
-        self.assertLess(user["consent_ts"], self.clock.time_msec())
+        self.assertEqual(user.consent_version, "1")
+        self.assertIsNotNone(user.consent_ts)
+        assert user.consent_ts is not None
+        self.assertGreater(user.consent_ts, before_consent)
+        self.assertLess(user.consent_ts, self.clock.time_msec())
 
     def test_add_tokens(self) -> None:
         self.get_success(self.store.register_user(self.user_id, self.pwhash))
@@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
 
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user is not None
-        self.assertTrue(user["approved"])
+        self.assertTrue(user.approved)
 
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertTrue(approved)
@@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
 
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         assert user is not None
-        self.assertFalse(user["approved"])
+        self.assertFalse(user.approved)
 
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertFalse(approved)
@@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
         user = self.get_success(self.store.get_user_by_id(self.user_id))
         self.assertIsNotNone(user)
         assert user is not None
-        self.assertEqual(user["approved"], 1)
+        self.assertEqual(user.approved, 1)
 
         approved = self.get_success(self.store.is_user_approved(self.user_id))
         self.assertTrue(approved)