summary refs log tree commit diff
path: root/synapse/storage/databases/main/__init__.py
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2022-04-27 14:05:00 +0200
committerGitHub <noreply@github.com>2022-04-27 13:05:00 +0100
commitb76f1a4d5f918def1f643910939b80e9e035e07f (patch)
treeb0492ace0e54340b0b40d990a298c2c249274427 /synapse/storage/databases/main/__init__.py
parentBound ephemeral events by key (#12544) (diff)
downloadsynapse-b76f1a4d5f918def1f643910939b80e9e035e07f.tar.xz
Add some type hints to datastore (#12485)
Diffstat (limited to 'synapse/storage/databases/main/__init__.py')
-rw-r--r--synapse/storage/databases/main/__init__.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 951031af50..5895b89202 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,12 +15,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     IdGenerator,
     MultiWriterIdGenerator,
@@ -266,7 +271,9 @@ class DataStore(
             A tuple of a list of mappings from user to information and a count of total users.
         """
 
-        def get_users_paginate_txn(txn):
+        def get_users_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             filters = []
             args = [self.hs.config.server.server_name]
 
@@ -301,7 +308,7 @@ class DataStore(
                 """
             sql = "SELECT COUNT(*) as total_users " + sql_base
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = f"""
                 SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
@@ -338,7 +345,9 @@ class DataStore(
         )
 
 
-def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+def check_database_before_upgrade(
+    cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
     """Called before upgrading an existing database to check that it is broadly sane
     compared with the configuration.
     """