diff --git a/tests/utils.py b/tests/utils.py
index 6d013e8518..983859120f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -14,7 +14,12 @@
# limitations under the License.
import atexit
+import hashlib
import os
+import time
+import uuid
+import warnings
+from typing import Type
from unittest.mock import Mock, patch
from urllib import parse as urlparse
@@ -23,11 +28,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
+from synapse.server import HomeServer
+from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import create_engine
+from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database
# set this to True to run the tests against postgres instead of sqlite.
@@ -174,6 +182,171 @@ def default_config(name, parse=False):
return config_dict
+class TestHomeServer(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+
+def setup_test_homeserver(
+ cleanup_func,
+ name="test",
+ config=None,
+ reactor=None,
+ homeserver_to_use: Type[HomeServer] = TestHomeServer,
+ **kwargs,
+):
+ """
+ Setup a homeserver suitable for running tests against. Keyword arguments
+ are passed to the Homeserver constructor.
+
+ If no datastore is supplied, one is created and given to the homeserver.
+
+ Args:
+ cleanup_func : The function used to register a cleanup routine for
+ after the test.
+
+ Calling this method directly is deprecated: you should instead derive from
+ HomeserverTestCase.
+ """
+ if reactor is None:
+ from twisted.internet import reactor
+
+ if config is None:
+ config = default_config(name, parse=True)
+
+ config.ldap_enabled = False
+
+ if "clock" not in kwargs:
+ kwargs["clock"] = MockClock()
+
+ if USE_POSTGRES_FOR_TESTS:
+ test_db = "synapse_test_%s" % uuid.uuid4().hex
+
+ database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": test_db,
+ "host": POSTGRES_HOST,
+ "password": POSTGRES_PASSWORD,
+ "user": POSTGRES_USER,
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ database_config = {
+ "name": "sqlite3",
+ "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
+ }
+
+ if "db_txn_limit" in kwargs:
+ database_config["txn_limit"] = kwargs["db_txn_limit"]
+
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
+
+ # Create the database before we actually try and connect to it, based off
+ # the template database we generate in setupdb()
+ if isinstance(db_engine, PostgresEngine):
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ cur.execute(
+ "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+ )
+ cur.close()
+ db_conn.close()
+
+ hs = homeserver_to_use(
+ name,
+ config=config,
+ version_string="Synapse/tests",
+ reactor=reactor,
+ )
+
+ # Install @cache_in_self attributes
+ for key, val in kwargs.items():
+ setattr(hs, "_" + key, val)
+
+ # Mock TLS
+ hs.tls_server_context_factory = Mock()
+ hs.tls_client_options_factory = Mock()
+
+ hs.setup()
+ if homeserver_to_use == TestHomeServer:
+ hs.setup_background_tasks()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
+
+ # We need to do cleanup on PostgreSQL
+ def cleanup():
+ import psycopg2
+
+ # Close all the db pools
+ database._db_pool.close()
+
+ dropped = False
+
+ # Drop the test database
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for _ in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
+ cur.close()
+ db_conn.close()
+
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
+
+ # bcrypt is far too slow to be doing in unit tests
+ # Need to let the HS build an auth handler and then mess with it
+ # because AuthHandler's constructor requires the HS, so we can't make one
+ # beforehand and pass it in to the HS's constructor (chicken / egg)
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
+
+ return hs
+
+
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
|