summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2023-09-14 12:46:30 +0100
committerGitHub <noreply@github.com>2023-09-14 12:46:30 +0100
commit954921736b88de25c775c519a206449e46b3bf07 (patch)
treed4e428b26c39659d9da44996d9d3db3bcf6f2ddd /synapse/storage
parentRemove a reference cycle in background process (#16314) (diff)
downloadsynapse-954921736b88de25c775c519a206449e46b3bf07.tar.xz
Refactor `get_user_by_id` (#16316)
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/client_ips.py11
-rw-r--r--synapse/storage/databases/main/registration.py76
2 files changed, 33 insertions, 54 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index d8d333e11d..7da47c3dd7 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
                     }
 
         return list(results.values())
+
+    async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
+        """Get the last seen timestamp for a user, if we have it."""
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="user_ips",
+            keyvalues={"user_id": user_id},
+            retcol="MAX(last_seen)",
+            allow_none=True,
+            desc="get_last_seen_for_user_id",
+        )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e34156dc55..cc964604e2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
@@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
 
     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
-        """Deprecated: use get_userinfo_by_id instead"""
+    async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
+        """Returns info about the user account, if it exists."""
 
         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
             # We could technically use simple_select_one here, but it would not perform
@@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             txn.execute(
                 """
                 SELECT
-                    name, password_hash, is_guest, admin, consent_version, consent_ts,
+                    name, is_guest, admin, consent_version, consent_ts,
                     consent_server_notice_sent, appservice_id, creation_ts, user_type,
                     deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
                     COALESCE(approved, TRUE) AS approved,
-                    COALESCE(locked, FALSE) AS locked, last_seen_ts
+                    COALESCE(locked, FALSE) AS locked
                 FROM users
-                LEFT JOIN (
-                    SELECT user_id, MAX(last_seen) AS last_seen_ts
-                    FROM user_ips GROUP BY user_id
-                ) ls ON users.name = ls.user_id
                 WHERE name = ?
                 """,
                 (user_id,),
@@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_id",
             func=get_user_by_id_txn,
         )
-
-        if row is not None:
-            # If we're using SQLite our boolean values will be integers. Because we
-            # present some of this data as is to e.g. server admins via REST APIs, we
-            # want to make sure we're returning the right type of data.
-            # Note: when adding a column name to this list, be wary of NULLable columns,
-            # since NULL values will be turned into False.
-            boolean_columns = [
-                "admin",
-                "deactivated",
-                "shadow_banned",
-                "approved",
-                "locked",
-            ]
-            for column in boolean_columns:
-                row[column] = bool(row[column])
-
-        return row
-
-    async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
-        """Get a UserInfo object for a user by user ID.
-
-        Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
-        this method should be cached.
-
-        Args:
-             user_id: The user to fetch user info for.
-        Returns:
-            `UserInfo` object if user found, otherwise `None`.
-        """
-        user_data = await self.get_user_by_id(user_id)
-        if not user_data:
+        if row is None:
             return None
+
         return UserInfo(
-            appservice_id=user_data["appservice_id"],
-            consent_server_notice_sent=user_data["consent_server_notice_sent"],
-            consent_version=user_data["consent_version"],
-            creation_ts=user_data["creation_ts"],
-            is_admin=bool(user_data["admin"]),
-            is_deactivated=bool(user_data["deactivated"]),
-            is_guest=bool(user_data["is_guest"]),
-            is_shadow_banned=bool(user_data["shadow_banned"]),
-            user_id=UserID.from_string(user_data["name"]),
-            user_type=user_data["user_type"],
-            last_seen_ts=user_data["last_seen_ts"],
+            appservice_id=row["appservice_id"],
+            consent_server_notice_sent=row["consent_server_notice_sent"],
+            consent_version=row["consent_version"],
+            consent_ts=row["consent_ts"],
+            creation_ts=row["creation_ts"],
+            is_admin=bool(row["admin"]),
+            is_deactivated=bool(row["deactivated"]),
+            is_guest=bool(row["is_guest"]),
+            is_shadow_banned=bool(row["shadow_banned"]),
+            user_id=UserID.from_string(row["name"]),
+            user_type=row["user_type"],
+            approved=bool(row["approved"]),
+            locked=bool(row["locked"]),
         )
 
     async def is_trial_user(self, user_id: str) -> bool:
@@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         now = self._clock.time_msec()
         days = self.config.server.mau_appservice_trial_days.get(
-            info["appservice_id"], self.config.server.mau_trial_days
+            info.appservice_id, self.config.server.mau_trial_days
         )
         trial_duration_ms = days * 24 * 60 * 60 * 1000
-        is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
+        is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
         return is_trial
 
     @cached()