summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py81
1 files changed, 45 insertions, 36 deletions
diff --git a/tests/utils.py b/tests/utils.py
index de33deb0b2..d1f59551e8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -19,18 +19,23 @@ import urllib
 import urlparse
 
 from mock import Mock, patch
-from twisted.enterprise.adbapi import ConnectionPool
 from twisted.internet import defer, reactor
 
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.federation.transport import server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util.logcontext import LoggingContext
 from synapse.util.ratelimitutils import FederationRateLimiter
 
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
+
 
 @defer.inlineCallbacks
 def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
@@ -60,30 +65,62 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.update_user_directory = False
 
     config.use_frozen_dicts = True
-    config.database_config = {"name": "sqlite3"}
     config.ldap_enabled = False
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
+    if USE_POSTGRES_FOR_TESTS:
+        config.database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": "synapse_test",
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        config.database_config = {
+            "name": "sqlite3",
+            "args": {
+                "database": ":memory:",
+                "cp_min": 1,
+                "cp_max": 1,
+            },
+        }
+
+    db_engine = create_engine(config.database_config)
+
+    # we need to configure the connection pool to run the on_new_connection
+    # function, so that we can test code that uses custom sqlite functions
+    # (like rank).
+    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
     if datastore is None:
-        db_pool = SQLiteMemoryDbPool()
-        yield db_pool.prepare()
         hs = HomeServer(
-            name, db_pool=db_pool, config=config,
+            name, config=config,
+            db_config=config.database_config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
-            get_db_conn=db_pool.get_db_conn,
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
         )
+        db_conn = hs.get_db_conn()
+        # make sure that the database is empty
+        if isinstance(db_engine, PostgresEngine):
+            cur = db_conn.cursor()
+            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+            rows = cur.fetchall()
+            for r in rows:
+                cur.execute("DROP TABLE %s CASCADE" % r[0])
+        yield prepare_database(db_conn, db_engine, config)
         hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
@@ -302,34 +339,6 @@ class MockClock(object):
         return d
 
 
-class SQLiteMemoryDbPool(ConnectionPool, object):
-    def __init__(self):
-        super(SQLiteMemoryDbPool, self).__init__(
-            "sqlite3", ":memory:",
-            cp_min=1,
-            cp_max=1,
-        )
-
-        self.config = Mock()
-        self.config.password_providers = []
-        self.config.database_config = {"name": "sqlite3"}
-
-    def prepare(self):
-        engine = self.create_engine()
-        return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine, self.config)
-        )
-
-    def get_db_conn(self):
-        conn = self.connect()
-        engine = self.create_engine()
-        prepare_database(conn, engine, self.config)
-        return conn
-
-    def create_engine(self):
-        return create_engine(self.config.database_config)
-
-
 def _format_call(args, kwargs):
     return ", ".join(
         ["%r" % (a) for a in args] +