diff options
Diffstat (limited to 'tests/utils.py')
-rw-r--r-- | tests/utils.py | 175 |
1 files changed, 174 insertions, 1 deletions
diff --git a/tests/utils.py b/tests/utils.py index 6d013e8518..983859120f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,12 @@ # limitations under the License. import atexit +import hashlib import os +import time +import uuid +import warnings +from typing import Type from unittest.mock import Mock, patch from urllib import parse as urlparse @@ -23,11 +28,14 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.logging.context import current_context, set_current_context +from synapse.server import HomeServer +from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.prepare_database import prepare_database # set this to True to run the tests against postgres instead of sqlite. @@ -174,6 +182,171 @@ def default_config(name, parse=False): return config_dict +class TestHomeServer(HomeServer): + DATASTORE_CLASS = DataStore + + +def setup_test_homeserver( + cleanup_func, + name="test", + config=None, + reactor=None, + homeserver_to_use: Type[HomeServer] = TestHomeServer, + **kwargs, +): + """ + Setup a homeserver suitable for running tests against. Keyword arguments + are passed to the Homeserver constructor. + + If no datastore is supplied, one is created and given to the homeserver. + + Args: + cleanup_func : The function used to register a cleanup routine for + after the test. + + Calling this method directly is deprecated: you should instead derive from + HomeserverTestCase. + """ + if reactor is None: + from twisted.internet import reactor + + if config is None: + config = default_config(name, parse=True) + + config.ldap_enabled = False + + if "clock" not in kwargs: + kwargs["clock"] = MockClock() + + if USE_POSTGRES_FOR_TESTS: + test_db = "synapse_test_%s" % uuid.uuid4().hex + + database_config = { + "name": "psycopg2", + "args": { + "database": test_db, + "host": POSTGRES_HOST, + "password": POSTGRES_PASSWORD, + "user": POSTGRES_USER, + "cp_min": 1, + "cp_max": 5, + }, + } + else: + database_config = { + "name": "sqlite3", + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + } + + if "db_txn_limit" in kwargs: + database_config["txn_limit"] = kwargs["db_txn_limit"] + + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) + + # Create the database before we actually try and connect to it, based off + # the template database we generate in setupdb() + if isinstance(db_engine, PostgresEngine): + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + cur.execute( + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) + ) + cur.close() + db_conn.close() + + hs = homeserver_to_use( + name, + config=config, + version_string="Synapse/tests", + reactor=reactor, + ) + + # Install @cache_in_self attributes + for key, val in kwargs.items(): + setattr(hs, "_" + key, val) + + # Mock TLS + hs.tls_server_context_factory = Mock() + hs.tls_client_options_factory = Mock() + + hs.setup() + if homeserver_to_use == TestHomeServer: + hs.setup_background_tasks() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] + + # We need to do cleanup on PostgreSQL + def cleanup(): + import psycopg2 + + # Close all the db pools + database._db_pool.close() + + dropped = False + + # Drop the test database + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + + # Try a few times to drop the DB. Some things may hold on to the + # database for a few more seconds due to flakiness, preventing + # us from dropping it when the test is over. If we can't drop + # it, warn and move on. + for _ in range(5): + try: + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) + db_conn.commit() + dropped = True + except psycopg2.OperationalError as e: + warnings.warn( + "Couldn't drop old db: " + str(e), category=UserWarning + ) + time.sleep(0.5) + + cur.close() + db_conn.close() + + if not dropped: + warnings.warn("Failed to drop old DB.", category=UserWarning) + + if not LEAVE_DB: + # Register the cleanup hook + cleanup_func(cleanup) + + # bcrypt is far too slow to be doing in unit tests + # Need to let the HS build an auth handler and then mess with it + # because AuthHandler's constructor requires the HS, so we can't make one + # beforehand and pass it in to the HS's constructor (chicken / egg) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash + + return hs + + def mock_getRawHeaders(headers=None): headers = headers if headers is not None else {} |