summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/utils.py60
1 files changed, 17 insertions, 43 deletions
diff --git a/tests/utils.py b/tests/utils.py
index ab5e2341c9..50de4199be 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -19,7 +19,6 @@ import urllib
 import urlparse
 
 from mock import Mock, patch
-from twisted.enterprise.adbapi import ConnectionPool
 from twisted.internet import defer, reactor
 
 from synapse.api.errors import CodeMessageException, cs_error
@@ -60,30 +59,37 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.update_user_directory = False
 
     config.use_frozen_dicts = True
-    config.database_config = {"name": "sqlite3"}
     config.ldap_enabled = False
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
+    config.database_config = {
+        "name": "sqlite3",
+        "args": {
+            "database": ":memory:",
+            "cp_min": 1,
+            "cp_max": 1,
+        },
+    }
     db_engine = create_engine(config.database_config)
+
+    # we need to configure the connection pool to run the on_new_connection
+    # function, so that we can test code that uses custom sqlite functions
+    # (like rank).
+    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
     if datastore is None:
-        # we need to configure the connection pool to run the on_new_connection
-        # function, so that we can test code that uses custom sqlite functions
-        # (like rank).
-        db_pool = SQLiteMemoryDbPool(
-            cp_openfun=db_engine.on_new_connection,
-        )
-        yield db_pool.prepare()
         hs = HomeServer(
-            name, db_pool=db_pool, config=config,
+            name, config=config,
+            db_config=config.database_config,
             version_string="Synapse/tests",
             database_engine=db_engine,
-            get_db_conn=db_pool.get_db_conn,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
         )
+        yield prepare_database(hs.get_db_conn(), db_engine, config)
         hs.setup()
     else:
         hs = HomeServer(
@@ -308,38 +314,6 @@ class MockClock(object):
         return d
 
 
-class SQLiteMemoryDbPool(ConnectionPool, object):
-    def __init__(self, **kwargs):
-        connkw = {
-            "cp_min": 1,
-            "cp_max": 1,
-        }
-        connkw.update(kwargs)
-
-        super(SQLiteMemoryDbPool, self).__init__(
-            "sqlite3", ":memory:", **connkw
-        )
-
-        self.config = Mock()
-        self.config.password_providers = []
-        self.config.database_config = {"name": "sqlite3"}
-
-    def prepare(self):
-        engine = self.create_engine()
-        return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine, self.config)
-        )
-
-    def get_db_conn(self):
-        conn = self.connect()
-        engine = self.create_engine()
-        prepare_database(conn, engine, self.config)
-        return conn
-
-    def create_engine(self):
-        return create_engine(self.config.database_config)
-
-
 def _format_call(args, kwargs):
     return ", ".join(
         ["%r" % (a) for a in args] +