diff --git a/tests/utils.py b/tests/utils.py
index 46ef2959f2..59c020a051 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,19 +30,16 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.storage.prepare_database import (
- _get_or_create_schema_state,
- _setup_new_database,
- prepare_database,
-)
+from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
@@ -77,7 +74,10 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
- cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
+ cur.execute(
+ "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
+ "template=template0;" % (POSTGRES_BASE_DB,)
+ )
cur.close()
db_conn.close()
@@ -88,11 +88,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
- cur = db_conn.cursor()
- _get_or_create_schema_state(cur, db_engine)
- _setup_new_database(cur, db_engine)
- db_conn.commit()
- cur.close()
+ prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
@@ -117,6 +113,7 @@ def default_config(name, parse=False):
"""
config_dict = {
"server_name": name,
+ "send_federation": False,
"media_store_path": "media",
"uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config
@@ -145,7 +142,6 @@ def default_config(name, parse=False):
"limit_usage_by_mau": False,
"hs_disabled": False,
"hs_disabled_message": "",
- "hs_disabled_limit_type": "",
"max_mau_value": 50,
"mau_trial_days": 0,
"mau_stats_only": False,
@@ -171,6 +167,7 @@ def default_config(name, parse=False):
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
"update_user_directory": False,
+ "caches": {"global_factor": 1},
}
if parse:
@@ -185,7 +182,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
-@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
@@ -222,7 +218,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
- config.database_config = {
+ database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -234,12 +230,15 @@ def setup_test_homeserver(
},
}
else:
- config.database_config = {
+ database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
- db_engine = create_engine(config.database_config)
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -259,39 +258,30 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # we need to configure the connection pool to run the on_new_connection
- # function, so that we can test code that uses custom sqlite functions
- # (like rank).
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
if datastore is None:
hs = homeserverToUse(
name,
config=config,
- db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
)
- # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
- # date db
- if not isinstance(db_engine, PostgresEngine):
- db_conn = hs.get_db_conn()
- yield prepare_database(db_conn, db_engine, config)
- db_conn.commit()
- db_conn.close()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_master()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- else:
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
- hs.get_db_pool().close()
+ database._db_pool.close()
dropped = False
@@ -330,17 +320,12 @@ def setup_test_homeserver(
# Register the cleanup hook
cleanup_func(cleanup)
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
else:
hs = homeserverToUse(
name,
- db_pool=None,
datastore=datastore,
config=config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
@@ -351,10 +336,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().validate_hash = (
- lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
- )
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed:
@@ -463,7 +453,9 @@ class MockHttpResource(HttpServer):
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
- (code, response) = yield func(mock_request, *args)
+ (code, response) = yield defer.ensureDeferred(
+ func(mock_request, *args)
+ )
return code, response
except CodeMessageException as e:
return (e.code, cs_error(e.msg, code=e.errcode))
@@ -510,10 +502,10 @@ class MockClock(object):
return self.time() * 1000
def call_later(self, delay, callback, *args, **kwargs):
- current_context = LoggingContext.current_context()
+ ctx = current_context()
def wrapped_callback():
- LoggingContext.thread_local.current_context = current_context
+ set_current_context(ctx)
callback(*args, **kwargs)
t = [self.now + delay, wrapped_callback, False]
@@ -521,8 +513,8 @@ class MockClock(object):
return t
- def looping_call(self, function, interval):
- self.loopers.append([function, interval / 1000.0, self.now])
+ def looping_call(self, function, interval, *args, **kwargs):
+ self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -552,9 +544,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
- func, interval, last = looped
+ func, interval, last, args, kwargs = looped
if last + interval < self.now:
- func()
+ func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):
@@ -655,10 +647,18 @@ def create_room(hs, room_id, creator_id):
creator_id (str)
"""
+ persistence_store = hs.get_storage().persistence
store = hs.get_datastore()
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
+ yield store.store_room(
+ room_id=room_id,
+ room_creator_user_id=creator_id,
+ is_public=False,
+ room_version=RoomVersions.V1,
+ )
+
builder = event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -672,4 +672,4 @@ def create_room(hs, room_id, creator_id):
event, context = yield event_creation_handler.create_new_client_event(builder)
- yield store.persist_event(event, context)
+ yield persistence_store.persist_event(event, context)
|