summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/server.py19
-rw-r--r--tests/utils.py5
2 files changed, 12 insertions, 12 deletions
diff --git a/tests/server.py b/tests/server.py
index 5a63ecee9f..74dd00cd3f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -971,8 +971,12 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex
 
+        if USE_POSTGRES_FOR_TESTS == "psycopg":
+            db_type = "psycopg"
+        else:
+            db_type = "psycopg2"
         database_config = {
-            "name": "psycopg2",
+            "name": db_type,
             "args": {
                 "dbname": test_db,
                 "host": POSTGRES_HOST,
@@ -1030,8 +1034,6 @@ def setup_test_homeserver(
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
     if isinstance(db_engine, PostgresEngine):
-        import psycopg2.extensions
-
         db_conn = db_engine.module.connect(
             dbname=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
@@ -1039,8 +1041,7 @@ def setup_test_homeserver(
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
         )
-        assert isinstance(db_conn, psycopg2.extensions.connection)
-        db_conn.autocommit = True
+        db_engine.attempt_to_set_autocommit(db_conn, True)
         cur = db_conn.cursor()
         cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
         cur.execute(
@@ -1070,9 +1071,6 @@ def setup_test_homeserver(
 
         # We need to do cleanup on PostgreSQL
         def cleanup() -> None:
-            import psycopg2
-            import psycopg2.extensions
-
             # Close all the db pools
             database_pool._db_pool.close()
 
@@ -1086,8 +1084,7 @@ def setup_test_homeserver(
                 port=POSTGRES_PORT,
                 password=POSTGRES_PASSWORD,
             )
-            assert isinstance(db_conn, psycopg2.extensions.connection)
-            db_conn.autocommit = True
+            db_engine.attempt_to_set_autocommit(db_conn, True)
             cur = db_conn.cursor()
 
             # Try a few times to drop the DB. Some things may hold on to the
@@ -1099,7 +1096,7 @@ def setup_test_homeserver(
                     cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                     db_conn.commit()
                     dropped = True
-                except psycopg2.OperationalError as e:
+                except db_engine.OperationalError as e:
                     warnings.warn(
                         "Couldn't drop old db: " + str(e),
                         category=UserWarning,
diff --git a/tests/utils.py b/tests/utils.py
index a0c87ad628..c44e5cb4ee 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -59,7 +59,10 @@ def setupdb() -> None:
     # If we're using PostgreSQL, set up the db once
     if USE_POSTGRES_FOR_TESTS:
         # create a PostgresEngine
-        db_engine = create_engine({"name": "psycopg2", "args": {}})
+        if USE_POSTGRES_FOR_TESTS == "psycopg":
+            db_engine = create_engine({"name": "psycopg", "args": {}})
+        else:
+            db_engine = create_engine({"name": "psycopg2", "args": {}})
         # connect to postgres to create the base database.
         db_conn = db_engine.module.connect(
             user=POSTGRES_USER,