diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/server.py | 19 | ||||
-rw-r--r-- | tests/utils.py | 5 |
2 files changed, 12 insertions, 12 deletions
diff --git a/tests/server.py b/tests/server.py index 5a63ecee9f..74dd00cd3f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -971,8 +971,12 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex + if USE_POSTGRES_FOR_TESTS == "psycopg": + db_type = "psycopg" + else: + db_type = "psycopg2" database_config = { - "name": "psycopg2", + "name": db_type, "args": { "dbname": test_db, "host": POSTGRES_HOST, @@ -1030,8 +1034,6 @@ def setup_test_homeserver( # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() if isinstance(db_engine, PostgresEngine): - import psycopg2.extensions - db_conn = db_engine.module.connect( dbname=POSTGRES_BASE_DB, user=POSTGRES_USER, @@ -1039,8 +1041,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) - assert isinstance(db_conn, psycopg2.extensions.connection) - db_conn.autocommit = True + db_engine.attempt_to_set_autocommit(db_conn, True) cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) cur.execute( @@ -1070,9 +1071,6 @@ def setup_test_homeserver( # We need to do cleanup on PostgreSQL def cleanup() -> None: - import psycopg2 - import psycopg2.extensions - # Close all the db pools database_pool._db_pool.close() @@ -1086,8 +1084,7 @@ def setup_test_homeserver( port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) - assert isinstance(db_conn, psycopg2.extensions.connection) - db_conn.autocommit = True + db_engine.attempt_to_set_autocommit(db_conn, True) cur = db_conn.cursor() # Try a few times to drop the DB. Some things may hold on to the @@ -1099,7 +1096,7 @@ def setup_test_homeserver( cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) db_conn.commit() dropped = True - except psycopg2.OperationalError as e: + except db_engine.OperationalError as e: warnings.warn( "Couldn't drop old db: " + str(e), category=UserWarning, diff --git a/tests/utils.py b/tests/utils.py index a0c87ad628..c44e5cb4ee 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -59,7 +59,10 @@ def setupdb() -> None: # If we're using PostgreSQL, set up the db once if USE_POSTGRES_FOR_TESTS: # create a PostgresEngine - db_engine = create_engine({"name": "psycopg2", "args": {}}) + if USE_POSTGRES_FOR_TESTS == "psycopg": + db_engine = create_engine({"name": "psycopg", "args": {}}) + else: + db_engine = create_engine({"name": "psycopg2", "args": {}}) # connect to postgres to create the base database. db_conn = db_engine.module.connect( user=POSTGRES_USER, |