diff --git a/tests/utils.py b/tests/utils.py
index de33deb0b2..ab5e2341c9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -66,13 +66,19 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
if "clock" not in kargs:
kargs["clock"] = MockClock()
+ db_engine = create_engine(config.database_config)
if datastore is None:
- db_pool = SQLiteMemoryDbPool()
+ # we need to configure the connection pool to run the on_new_connection
+ # function, so that we can test code that uses custom sqlite functions
+ # (like rank).
+ db_pool = SQLiteMemoryDbPool(
+ cp_openfun=db_engine.on_new_connection,
+ )
yield db_pool.prepare()
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
get_db_conn=db_pool.get_db_conn,
room_list_handler=object(),
tls_server_context_factory=Mock(),
@@ -83,7 +89,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
@@ -303,11 +309,15 @@ class MockClock(object):
class SQLiteMemoryDbPool(ConnectionPool, object):
- def __init__(self):
+ def __init__(self, **kwargs):
+ connkw = {
+ "cp_min": 1,
+ "cp_max": 1,
+ }
+ connkw.update(kwargs)
+
super(SQLiteMemoryDbPool, self).__init__(
- "sqlite3", ":memory:",
- cp_min=1,
- cp_max=1,
+ "sqlite3", ":memory:", **connkw
)
self.config = Mock()
|