diff options
-rw-r--r-- | tests/utils.py | 61 |
1 files changed, 37 insertions, 24 deletions
diff --git a/tests/utils.py b/tests/utils.py index 3059c453d5..1d06398b48 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,7 @@ import atexit import os +from typing import TYPE_CHECKING, cast from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions @@ -22,7 +23,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -48,21 +49,27 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None # the dbname we will connect to in order to create the base database. POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" +if TYPE_CHECKING: + import psycopg2.extensions -def setupdb(): +def setupdb() -> None: # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: # create a PostgresEngine - db_engine = create_engine({"name": "psycopg2", "args": {}}) - + db_engine = cast( + PostgresEngine, create_engine({"name": "psycopg2", "args": {}}) + ) # connect to postgres to create the base database. - db_conn = db_engine.module.connect( - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + db_conn = cast( + "psycopg2.extensions.connection", + db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ), ) db_conn.autocommit = True cur = db_conn.cursor() @@ -75,24 +82,30 @@ def setupdb(): db_conn.close() # Set up in the db - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - port=POSTGRES_PORT, - password=POSTGRES_PASSWORD, - ) - db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") - prepare_database(db_conn, db_engine, None) - db_conn.close() - - def _cleanup(): - db_conn = db_engine.module.connect( + db_conn = cast( + "psycopg2.extensions.connection", + db_engine.module.connect( + database=POSTGRES_BASE_DB, user=POSTGRES_USER, host=POSTGRES_HOST, port=POSTGRES_PORT, password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ), + ) + logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") + prepare_database(logging_conn, db_engine, None) + logging_conn.close() + + def _cleanup() -> None: + db_conn = cast( + "psycopg2.extensions.connection", + db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + port=POSTGRES_PORT, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ), ) db_conn.autocommit = True cur = db_conn.cursor() |