summary refs log tree commit diff
path: root/synapse/storage/databases/main/registration.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-10-05 11:07:38 -0400
committerGitHub <noreply@github.com>2023-10-05 11:07:38 -0400
commitfa907025f4b263d27c2b338fb0fe86d257d74fa8 (patch)
tree325bfac7ba840c71a0b9a7b3e7a72e3549f03c54 /synapse/storage/databases/main/registration.py
parentAdd __slots__ to replication commands. (#16429) (diff)
downloadsynapse-fa907025f4b263d27c2b338fb0fe86d257d74fa8.tar.xz
Remove manys calls to cursor_to_dict (#16431)
This avoids calling cursor_to_dict and then immediately
unpacking the values in the dict for other users. By not
creating the intermediate dictionary we can avoid allocating
the dictionary and strings for the keys, which should generally
be more performant.

Additionally this improves type hints by avoid Dict[str, Any]
dictionaries coming out of the database layer.
Diffstat (limited to 'synapse/storage/databases/main/registration.py')
-rw-r--r--synapse/storage/databases/main/registration.py133
1 files changed, 77 insertions, 56 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cc964604e2..64a2c31a5d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
         """Returns info about the user account, if it exists."""
 
-        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
             # We could technically use simple_select_one here, but it would not perform
             # the COALESCEs (unless hacked into the column names), which could yield
             # confusing results.
@@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 (user_id,),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
-
-            if len(rows) == 0:
+            row = txn.fetchone()
+            if not row:
                 return None
 
-            return rows[0]
+            (
+                name,
+                is_guest,
+                admin,
+                consent_version,
+                consent_ts,
+                consent_server_notice_sent,
+                appservice_id,
+                creation_ts,
+                user_type,
+                deactivated,
+                shadow_banned,
+                approved,
+                locked,
+            ) = row
+
+            return UserInfo(
+                appservice_id=appservice_id,
+                consent_server_notice_sent=consent_server_notice_sent,
+                consent_version=consent_version,
+                consent_ts=consent_ts,
+                creation_ts=creation_ts,
+                is_admin=bool(admin),
+                is_deactivated=bool(deactivated),
+                is_guest=bool(is_guest),
+                is_shadow_banned=bool(shadow_banned),
+                user_id=UserID.from_string(name),
+                user_type=user_type,
+                approved=bool(approved),
+                locked=bool(locked),
+            )
 
-        row = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_user_by_id",
             func=get_user_by_id_txn,
         )
-        if row is None:
-            return None
-
-        return UserInfo(
-            appservice_id=row["appservice_id"],
-            consent_server_notice_sent=row["consent_server_notice_sent"],
-            consent_version=row["consent_version"],
-            consent_ts=row["consent_ts"],
-            creation_ts=row["creation_ts"],
-            is_admin=bool(row["admin"]),
-            is_deactivated=bool(row["deactivated"]),
-            is_guest=bool(row["is_guest"]),
-            is_shadow_banned=bool(row["shadow_banned"]),
-            user_id=UserID.from_string(row["name"]),
-            user_type=row["user_type"],
-            approved=bool(row["approved"]),
-            locked=bool(row["locked"]),
-        )
 
     async def is_trial_user(self, user_id: str) -> bool:
         """Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """
 
         txn.execute(sql, (token,))
-        rows = self.db_pool.cursor_to_dict(txn)
-
-        if rows:
-            row = rows[0]
-
-            # This field is nullable, ensure it comes out as a boolean
-            if row["token_used"] is None:
-                row["token_used"] = False
+        row = txn.fetchone()
 
-            return TokenLookupResult(**row)
+        if row:
+            (
+                user_id,
+                is_guest,
+                shadow_banned,
+                token_id,
+                device_id,
+                valid_until_ms,
+                token_owner,
+                token_used,
+            ) = row
+
+            return TokenLookupResult(
+                user_id=user_id,
+                is_guest=is_guest,
+                shadow_banned=shadow_banned,
+                token_id=token_id,
+                device_id=device_id,
+                valid_until_ms=valid_until_ms,
+                token_owner=token_owner,
+                # This field is nullable, ensure it comes out as a boolean
+                token_used=bool(token_used),
+            )
 
         return None
 
@@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Counts all users registered on the homeserver."""
 
         def _count_users(txn: LoggingTransaction) -> int:
-            txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.db_pool.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            txn.execute("SELECT COUNT(*) FROM users")
+            row = txn.fetchone()
+            assert row is not None
+            return row[0]
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
@@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Counts all users without a special user_type registered on the homeserver."""
 
         def _count_users(txn: LoggingTransaction) -> int:
-            txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
-            rows = self.db_pool.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+            row = txn.fetchone()
+            assert row is not None
+            return row[0]
 
         return await self.db_pool.runInteraction("count_real_users", _count_users)
 
@@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
             txn.execute(sql, [])
 
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
+            for (name,) in txn.fetchall():
+                self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
 
         await self.db_pool.runInteraction(
             "get_users_with_no_expiration_date",
@@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 (user_id,),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            row = txn.fetchone()
+            assert row is not None
 
             # We cast to bool because the value returned by the database engine might
             # be an integer if we're using SQLite.
-            return bool(rows[0]["approved"])
+            return bool(row[0])
 
         return await self.db_pool.runInteraction(
             desc="is_user_pending_approval",
@@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 (last_user, batch_size),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
             if not rows:
                 return True, 0
 
             rows_processed_nb = 0
 
-            for user in rows:
-                if not user["count_tokens"] and not user["count_threepids"]:
-                    self.set_user_deactivated_status_txn(txn, user["name"], True)
+            for name, count_tokens, count_threepids in rows:
+                if not count_tokens and not count_threepids:
+                    self.set_user_deactivated_status_txn(txn, name, True)
                     rows_processed_nb += 1
 
             logger.info("Marked %d rows as deactivated", rows_processed_nb)
 
             self.db_pool.updates._background_update_progress_txn(
-                txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+                txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
             )
 
             if batch_size > len(rows):