diff --git a/tests/utils.py b/tests/utils.py
index f8c7ad2604..513f358f4f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,19 +30,16 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
+from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.storage.prepare_database import (
- _get_or_create_schema_state,
- _setup_new_database,
- prepare_database,
-)
-from synapse.util.logcontext import LoggingContext
+from synapse.storage.prepare_database import prepare_database
from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
@@ -88,11 +85,7 @@ def setupdb():
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
- cur = db_conn.cursor()
- _get_or_create_schema_state(cur, db_engine)
- _setup_new_database(cur, db_engine)
- db_conn.commit()
- cur.close()
+ prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
@@ -117,6 +110,7 @@ def default_config(name, parse=False):
"""
config_dict = {
"server_name": name,
+ "send_federation": False,
"media_store_path": "media",
"uploads_path": "uploads",
# the test signing key is just an arbitrary ed25519 key to keep the config
@@ -126,7 +120,6 @@ def default_config(name, parse=False):
"enable_registration": True,
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
- "expire_access_token": False,
"trusted_third_party_id_servers": [],
"room_invite_state_types": [],
"password_providers": [],
@@ -146,18 +139,11 @@ def default_config(name, parse=False):
"limit_usage_by_mau": False,
"hs_disabled": False,
"hs_disabled_message": "",
- "hs_disabled_limit_type": "",
"max_mau_value": 50,
"mau_trial_days": 0,
"mau_stats_only": False,
"mau_limits_reserved_threepids": [],
"admin_contact": None,
- "rc_federation": {
- "reject_limit": 10,
- "sleep_limit": 10,
- "sleep_delay": 10,
- "concurrent": 10,
- },
"rc_message": {"per_second": 10000, "burst_count": 10000},
"rc_registration": {"per_second": 10000, "burst_count": 10000},
"rc_login": {
@@ -182,7 +168,7 @@ def default_config(name, parse=False):
if parse:
config = HomeServerConfig()
- config.parse_config_dict(config_dict)
+ config.parse_config_dict(config_dict, "", "")
return config
return config_dict
@@ -192,7 +178,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
-@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
@@ -229,7 +214,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
- config.database_config = {
+ database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -241,12 +226,15 @@ def setup_test_homeserver(
},
}
else:
- config.database_config = {
+ database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
- db_engine = create_engine(config.database_config)
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -266,39 +254,30 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # we need to configure the connection pool to run the on_new_connection
- # function, so that we can test code that uses custom sqlite functions
- # (like rank).
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
if datastore is None:
hs = homeserverToUse(
name,
config=config,
- db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
)
- # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
- # date db
- if not isinstance(db_engine, PostgresEngine):
- db_conn = hs.get_db_conn()
- yield prepare_database(db_conn, db_engine, config)
- db_conn.commit()
- db_conn.close()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_master()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- else:
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
- hs.get_db_pool().close()
+ database._db_pool.close()
dropped = False
@@ -337,17 +316,12 @@ def setup_test_homeserver(
# Register the cleanup hook
cleanup_func(cleanup)
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
else:
hs = homeserverToUse(
name,
- db_pool=None,
datastore=datastore,
config=config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
@@ -358,16 +332,16 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
+ hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().validate_hash = (
- lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
+ lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
)
fed = kargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)
- defer.returnValue(hs)
+ return hs
def register_federation_servlets(hs, resource):
@@ -407,7 +381,7 @@ class MockHttpResource(HttpServer):
def trigger_get(self, path):
return self.trigger(b"GET", path, None)
- @patch('twisted.web.http.Request')
+ @patch("twisted.web.http.Request")
@defer.inlineCallbacks
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
@@ -431,12 +405,12 @@ class MockHttpResource(HttpServer):
# annoyingly we return a twisted http request which has chained calls
# to get at the http content, hence mock it here.
mock_content = Mock()
- config = {'read.return_value': content}
+ config = {"read.return_value": content}
mock_content.configure_mock(**config)
mock_request.content = mock_content
- mock_request.method = http_method.encode('ascii')
- mock_request.uri = path.encode('ascii')
+ mock_request.method = http_method.encode("ascii")
+ mock_request.uri = path.encode("ascii")
mock_request.getClientIP.return_value = "-"
@@ -452,14 +426,14 @@ class MockHttpResource(HttpServer):
# add in query params to the right place
try:
- mock_request.args = urlparse.parse_qs(path.split('?')[1])
- mock_request.path = path.split('?')[0]
+ mock_request.args = urlparse.parse_qs(path.split("?")[1])
+ mock_request.path = path.split("?")[0]
path = mock_request.path
except Exception:
pass
if isinstance(path, bytes):
- path = path.decode('utf8')
+ path = path.decode("utf8")
for (method, pattern, func) in self.callbacks:
if http_method != method:
@@ -470,14 +444,16 @@ class MockHttpResource(HttpServer):
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
- (code, response) = yield func(mock_request, *args)
- defer.returnValue((code, response))
+ (code, response) = yield defer.ensureDeferred(
+ func(mock_request, *args)
+ )
+ return code, response
except CodeMessageException as e:
- defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
+ return (e.code, cs_error(e.msg, code=e.errcode))
raise KeyError("No event can handle %s" % path)
- def register_paths(self, method, path_patterns, callback):
+ def register_paths(self, method, path_patterns, callback, servlet_name):
for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback))
@@ -662,10 +638,18 @@ def create_room(hs, room_id, creator_id):
creator_id (str)
"""
+ persistence_store = hs.get_storage().persistence
store = hs.get_datastore()
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
+ yield store.store_room(
+ room_id=room_id,
+ room_creator_user_id=creator_id,
+ is_public=False,
+ room_version=RoomVersions.V1,
+ )
+
builder = event_builder_factory.for_room_version(
RoomVersions.V1,
{
@@ -679,4 +663,4 @@ def create_room(hs, room_id, creator_id):
event, context = yield event_creation_handler.create_new_client_event(builder)
- yield store.persist_event(event, context)
+ yield persistence_store.persist_event(event, context)
|