summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tests/utils.py61
1 files changed, 37 insertions, 24 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 3059c453d5..1d06398b48 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,6 +15,7 @@
 
 import atexit
 import os
+from typing import TYPE_CHECKING, cast
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
@@ -22,7 +23,7 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.logging.context import current_context, set_current_context
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import create_engine
+from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
 
 # set this to True to run the tests against postgres instead of sqlite.
@@ -48,21 +49,27 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
 
 # the dbname we will connect to in order to create the base database.
 POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
+if TYPE_CHECKING:
+    import psycopg2.extensions
 
 
-def setupdb():
+def setupdb() -> None:
     # If we're using PostgreSQL, set up the db once
     if USE_POSTGRES_FOR_TESTS:
         # create a PostgresEngine
-        db_engine = create_engine({"name": "psycopg2", "args": {}})
-
+        db_engine = cast(
+            PostgresEngine, create_engine({"name": "psycopg2", "args": {}})
+        )
         # connect to postgres to create the base database.
-        db_conn = db_engine.module.connect(
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            port=POSTGRES_PORT,
-            password=POSTGRES_PASSWORD,
-            dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+        db_conn = cast(
+            "psycopg2.extensions.connection",
+            db_engine.module.connect(
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                port=POSTGRES_PORT,
+                password=POSTGRES_PASSWORD,
+                dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+            ),
         )
         db_conn.autocommit = True
         cur = db_conn.cursor()
@@ -75,24 +82,30 @@ def setupdb():
         db_conn.close()
 
         # Set up in the db
-        db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB,
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            port=POSTGRES_PORT,
-            password=POSTGRES_PASSWORD,
-        )
-        db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
-        prepare_database(db_conn, db_engine, None)
-        db_conn.close()
-
-        def _cleanup():
-            db_conn = db_engine.module.connect(
+        db_conn = cast(
+            "psycopg2.extensions.connection",
+            db_engine.module.connect(
+                database=POSTGRES_BASE_DB,
                 user=POSTGRES_USER,
                 host=POSTGRES_HOST,
                 port=POSTGRES_PORT,
                 password=POSTGRES_PASSWORD,
-                dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+            ),
+        )
+        logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
+        prepare_database(logging_conn, db_engine, None)
+        logging_conn.close()
+
+        def _cleanup() -> None:
+            db_conn = cast(
+                "psycopg2.extensions.connection",
+                db_engine.module.connect(
+                    user=POSTGRES_USER,
+                    host=POSTGRES_HOST,
+                    port=POSTGRES_PORT,
+                    password=POSTGRES_PASSWORD,
+                    dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+                ),
             )
             db_conn.autocommit = True
             cur = db_conn.cursor()