diff --git a/tests/utils.py b/tests/utils.py
index dfbee5c23a..52405502e9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,6 +20,7 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
+from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -50,6 +51,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.macaroon_secret_key = "not even a little secret"
config.server_name = "server.under.test"
config.trusted_third_party_id_servers = []
+ config.room_invite_state_types = []
+
+ config.database_config = {"name": "sqlite3"}
if "clock" not in kargs:
kargs["clock"] = MockClock()
@@ -60,7 +64,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
- database_engine=create_engine("sqlite3"),
+ database_engine=create_engine(config),
get_db_conn=db_pool.get_db_conn,
**kargs
)
@@ -69,7 +73,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine("sqlite3"),
+ database_engine=create_engine(config),
**kargs
)
@@ -278,18 +282,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
cp_max=1,
)
+ self.config = Mock()
+ self.config.database_config = {"name": "sqlite3"}
+
def prepare(self):
- engine = create_engine("sqlite3")
+ engine = self.create_engine()
return self.runWithConnection(
- lambda conn: prepare_database(conn, engine)
+ lambda conn: prepare_database(conn, engine, self.config)
)
def get_db_conn(self):
conn = self.connect()
- engine = create_engine("sqlite3")
- prepare_database(conn, engine)
+ engine = self.create_engine()
+ prepare_database(conn, engine, self.config)
return conn
+ def create_engine(self):
+ return create_engine(self.config)
+
class MemoryDataStore(object):
@@ -502,3 +512,7 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
+
+
+def requester_for_user(user):
+ return Requester(user, None, False)
|