summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py30
1 files changed, 22 insertions, 8 deletions
diff --git a/tests/utils.py b/tests/utils.py
index bf7a31ff9e..291b549053 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,6 +20,7 @@ from synapse.storage.prepare_database import prepare_database
 from synapse.storage.engines import create_engine
 from synapse.server import HomeServer
 from synapse.federation.transport import server
+from synapse.types import Requester
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from synapse.util.logcontext import LoggingContext
@@ -51,6 +52,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.server_name = "server.under.test"
         config.trusted_third_party_id_servers = []
 
+    config.database_config = {"name": "sqlite3"}
+
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
@@ -60,7 +63,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
             get_db_conn=db_pool.get_db_conn,
             **kargs
         )
@@ -69,7 +72,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
             **kargs
         )
 
@@ -239,9 +242,10 @@ class MockClock(object):
     def looping_call(self, function, interval):
         pass
 
-    def cancel_call_later(self, timer):
+    def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
-            raise Exception("Cannot cancel an expired timer")
+            if not ignore_errs:
+                raise Exception("Cannot cancel an expired timer")
 
         timer[2] = True
         self.timers = [t for t in self.timers if t != timer]
@@ -277,18 +281,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
             cp_max=1,
         )
 
+        self.config = Mock()
+        self.config.database_config = {"name": "sqlite3"}
+
     def prepare(self):
-        engine = create_engine("sqlite3")
+        engine = self.create_engine()
         return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine)
+            lambda conn: prepare_database(conn, engine, self.config)
         )
 
     def get_db_conn(self):
         conn = self.connect()
-        engine = create_engine("sqlite3")
-        prepare_database(conn, engine)
+        engine = self.create_engine()
+        prepare_database(conn, engine, self.config)
         return conn
 
+    def create_engine(self):
+        return create_engine(self.config)
+
 
 class MemoryDataStore(object):
 
@@ -501,3 +511,7 @@ class DeferredMockCallable(object):
                     "call(%s)" % _format_call(c[0], c[1]) for c in calls
                 ])
             )
+
+
+def requester_for_user(user):
+    return Requester(user, None, False)