summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py133
1 files changed, 77 insertions, 56 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..64a2c31a5d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
         """Returns info about the user account, if it exists."""
 
-        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
             # We could technically use simple_select_one here, but it would not perform
             # the COALESCEs (unless hacked into the column names), which could yield
             # confusing results.
@@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 (user_id,),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
-
-            if len(rows) == 0:
+            row = txn.fetchone()
+            if not row:
                 return None
 
-            return rows[0]
+            (
+                name,
+                is_guest,
+                admin,
+                consent_version,
+                consent_ts,
+                consent_server_notice_sent,
+                appservice_id,
+                creation_ts,
+                user_type,
+                deactivated,
+                shadow_banned,
+                approved,
+                locked,
+            ) = row
+
+            return UserInfo(
+                appservice_id=appservice_id,
+                consent_server_notice_sent=consent_server_notice_sent,
+                consent_version=consent_version,
+                consent_ts=consent_ts,
+                creation_ts=creation_ts,
+                is_admin=bool(admin),
+                is_deactivated=bool(deactivated),
+                is_guest=bool(is_guest),
+                is_shadow_banned=bool(shadow_banned),
+                user_id=UserID.from_string(name),
+                user_type=user_type,
+                approved=bool(approved),
+                locked=bool(locked),
+            )
 
-        row = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_user_by_id",
             func=get_user_by_id_txn,
         )
-        if row is None:
-            return None
-
-        return UserInfo(
-            appservice_id=row["appservice_id"],
-            consent_server_notice_sent=row["consent_server_notice_sent"],
-            consent_version=row["consent_version"],
-            consent_ts=row["consent_ts"],
-            creation_ts=row["creation_ts"],
-            is_admin=bool(row["admin"]),
-            is_deactivated=bool(row["deactivated"]),
-            is_guest=bool(row["is_guest"]),
-            is_shadow_banned=bool(row["shadow_banned"]),
-            user_id=UserID.from_string(row["name"]),
-            user_type=row["user_type"],
-            approved=bool(row["approved"]),
-            locked=bool(row["locked"]),
-        )
 
     async def is_trial_user(self, user_id: str) -> bool:
         """Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """
 
         txn.execute(sql, (token,))
-        rows = self.db_pool.cursor_to_dict(txn)
-
-        if rows:
-            row = rows[0]
-
-            # This field is nullable, ensure it comes out as a boolean
-            if row["token_used"] is None:
-                row["token_used"] = False
+        row = txn.fetchone()
 
-            return TokenLookupResult(**row)
+        if row:
+            (
+                user_id,
+                is_guest,
+                shadow_banned,
+                token_id,
+                device_id,
+                valid_until_ms,
+                token_owner,
+                token_used,
+            ) = row
+
+            return TokenLookupResult(
+                user_id=user_id,
+                is_guest=is_guest,
+                shadow_banned=shadow_banned,
+                token_id=token_id,
+                device_id=device_id,
+                valid_until_ms=valid_until_ms,
+                token_owner=token_owner,
+                # This field is nullable, ensure it comes out as a boolean
+                token_used=bool(token_used),
+            )
 
         return None
 
@@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Counts all users registered on the homeserver."""
 
         def _count_users(txn: LoggingTransaction) -> int:
-            txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.db_pool.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            txn.execute("SELECT COUNT(*) FROM users")
+            row = txn.fetchone()
+            assert row is not None
+            return row[0]
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
@@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Counts all users without a special user_type registered on the homeserver."""
 
         def _count_users(txn: LoggingTransaction) -> int:
-            txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
-            rows = self.db_pool.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+            row = txn.fetchone()
+            assert row is not None
+            return row[0]
 
         return await self.db_pool.runInteraction("count_real_users", _count_users)
 
@@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
             txn.execute(sql, [])
 
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
+            for (name,) in txn.fetchall():
+                self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
 
         await self.db_pool.runInteraction(
             "get_users_with_no_expiration_date",
@@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 (user_id,),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            row = txn.fetchone()
+            assert row is not None
 
             # We cast to bool because the value returned by the database engine might
             # be an integer if we're using SQLite.
-            return bool(rows[0]["approved"])
+            return bool(row[0])
 
         return await self.db_pool.runInteraction(
             desc="is_user_pending_approval",
@@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 (last_user, batch_size),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
             if not rows:
                 return True, 0
 
             rows_processed_nb = 0
 
-            for user in rows:
-                if not user["count_tokens"] and not user["count_threepids"]:
-                    self.set_user_deactivated_status_txn(txn, user["name"], True)
+            for name, count_tokens, count_threepids in rows:
+                if not count_tokens and not count_threepids:
+                    self.set_user_deactivated_status_txn(txn, name, True)
                     rows_processed_nb += 1
 
             logger.info("Marked %d rows as deactivated", rows_processed_nb)
 
             self.db_pool.updates._background_update_progress_txn(
-                txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+                txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
             )
 
             if batch_size > len(rows):