diff --git a/tests/utils.py b/tests/utils.py
index 50de4199be..d1f59551e8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -25,11 +25,17 @@ from synapse.api.errors import CodeMessageException, cs_error
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.logcontext import LoggingContext
from synapse.util.ratelimitutils import FederationRateLimiter
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
+
@defer.inlineCallbacks
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
@@ -64,14 +70,25 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
if "clock" not in kargs:
kargs["clock"] = MockClock()
- config.database_config = {
- "name": "sqlite3",
- "args": {
- "database": ":memory:",
- "cp_min": 1,
- "cp_max": 1,
- },
- }
+ if USE_POSTGRES_FOR_TESTS:
+ config.database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": "synapse_test",
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ config.database_config = {
+ "name": "sqlite3",
+ "args": {
+ "database": ":memory:",
+ "cp_min": 1,
+ "cp_max": 1,
+ },
+ }
+
db_engine = create_engine(config.database_config)
# we need to configure the connection pool to run the on_new_connection
@@ -89,7 +106,15 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
tls_server_context_factory=Mock(),
**kargs
)
- yield prepare_database(hs.get_db_conn(), db_engine, config)
+ db_conn = hs.get_db_conn()
+ # make sure that the database is empty
+ if isinstance(db_engine, PostgresEngine):
+ cur = db_conn.cursor()
+ cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+ rows = cur.fetchall()
+ for r in rows:
+ cur.execute("DROP TABLE %s CASCADE" % r[0])
+ yield prepare_database(db_conn, db_engine, config)
hs.setup()
else:
hs = HomeServer(
|