summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/utils.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/tests/utils.py b/tests/utils.py
index de33deb0b2..ab5e2341c9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -66,13 +66,19 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
+    db_engine = create_engine(config.database_config)
     if datastore is None:
-        db_pool = SQLiteMemoryDbPool()
+        # we need to configure the connection pool to run the on_new_connection
+        # function, so that we can test code that uses custom sqlite functions
+        # (like rank).
+        db_pool = SQLiteMemoryDbPool(
+            cp_openfun=db_engine.on_new_connection,
+        )
         yield db_pool.prepare()
         hs = HomeServer(
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
+            database_engine=db_engine,
             get_db_conn=db_pool.get_db_conn,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
@@ -83,7 +89,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
@@ -303,11 +309,15 @@ class MockClock(object):
 
 
 class SQLiteMemoryDbPool(ConnectionPool, object):
-    def __init__(self):
+    def __init__(self, **kwargs):
+        connkw = {
+            "cp_min": 1,
+            "cp_max": 1,
+        }
+        connkw.update(kwargs)
+
         super(SQLiteMemoryDbPool, self).__init__(
-            "sqlite3", ":memory:",
-            cp_min=1,
-            cp_max=1,
+            "sqlite3", ":memory:", **connkw
         )
 
         self.config = Mock()