summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py43
1 files changed, 34 insertions, 9 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 50de4199be..d1f59551e8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -25,11 +25,17 @@ from synapse.api.errors import CodeMessageException, cs_error
 from synapse.federation.transport import server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util.logcontext import LoggingContext
 from synapse.util.ratelimitutils import FederationRateLimiter
 
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
+
 
 @defer.inlineCallbacks
 def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
@@ -64,14 +70,25 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
-    config.database_config = {
-        "name": "sqlite3",
-        "args": {
-            "database": ":memory:",
-            "cp_min": 1,
-            "cp_max": 1,
-        },
-    }
+    if USE_POSTGRES_FOR_TESTS:
+        config.database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": "synapse_test",
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        config.database_config = {
+            "name": "sqlite3",
+            "args": {
+                "database": ":memory:",
+                "cp_min": 1,
+                "cp_max": 1,
+            },
+        }
+
     db_engine = create_engine(config.database_config)
 
     # we need to configure the connection pool to run the on_new_connection
@@ -89,7 +106,15 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
             tls_server_context_factory=Mock(),
             **kargs
         )
-        yield prepare_database(hs.get_db_conn(), db_engine, config)
+        db_conn = hs.get_db_conn()
+        # make sure that the database is empty
+        if isinstance(db_engine, PostgresEngine):
+            cur = db_conn.cursor()
+            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+            rows = cur.fetchall()
+            for r in rows:
+                cur.execute("DROP TABLE %s CASCADE" % r[0])
+        yield prepare_database(db_conn, db_engine, config)
         hs.setup()
     else:
         hs = HomeServer(