diff options
Diffstat (limited to 'tests/server.py')
-rw-r--r-- | tests/server.py | 199 |
1 files changed, 9 insertions, 190 deletions
diff --git a/tests/server.py b/tests/server.py index b29df37595..40cf5b12c3 100644 --- a/tests/server.py +++ b/tests/server.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hashlib + import json import logging -import time -import uuid -import warnings from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -30,7 +27,6 @@ from typing import ( Type, Union, ) -from unittest.mock import Mock import attr from typing_extensions import Deque @@ -57,24 +53,11 @@ from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site -from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest -from synapse.server import HomeServer -from synapse.storage import DataStore -from synapse.storage.engines import PostgresEngine, create_engine from synapse.types import JsonDict from synapse.util import Clock -from tests.utils import ( - LEAVE_DB, - POSTGRES_BASE_DB, - POSTGRES_HOST, - POSTGRES_PASSWORD, - POSTGRES_USER, - USE_POSTGRES_FOR_TESTS, - MockClock, - default_config, -) +from tests.utils import setup_test_homeserver as _sth logger = logging.getLogger(__name__) @@ -467,11 +450,14 @@ class ThreadPool: return d -def make_test_homeserver_synchronous(server: HomeServer) -> None: +def setup_test_homeserver(cleanup_func, *args, **kwargs): """ - Make the given test homeserver's database interactions synchronous. + Set up a synchronous test server, driven by the reactor used by + the homeserver. """ + server = _sth(cleanup_func, *args, **kwargs) + # Make the thread pool synchronous. clock = server.get_clock() for database in server.get_datastores().databases: @@ -499,7 +485,6 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction - # Replace the thread pool with a threadless 'thread' pool pool.threadpool = ThreadPool(clock._reactor) pool.running = True @@ -507,6 +492,8 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None: # thread, so we need to disable the dedicated thread behaviour. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False + return server + def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() @@ -686,171 +673,3 @@ def connect_client( client.makeConnection(FakeTransport(server, reactor)) return client, server - - -class TestHomeServer(HomeServer): - DATASTORE_CLASS = DataStore - - -def setup_test_homeserver( - cleanup_func, - name="test", - config=None, - reactor=None, - homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs, -): - """ - Setup a homeserver suitable for running tests against. Keyword arguments - are passed to the Homeserver constructor. - - If no datastore is supplied, one is created and given to the homeserver. - - Args: - cleanup_func : The function used to register a cleanup routine for - after the test. - - Calling this method directly is deprecated: you should instead derive from - HomeserverTestCase. - """ - if reactor is None: - from twisted.internet import reactor - - if config is None: - config = default_config(name, parse=True) - - config.ldap_enabled = False - - if "clock" not in kwargs: - kwargs["clock"] = MockClock() - - if USE_POSTGRES_FOR_TESTS: - test_db = "synapse_test_%s" % uuid.uuid4().hex - - database_config = { - "name": "psycopg2", - "args": { - "database": test_db, - "host": POSTGRES_HOST, - "password": POSTGRES_PASSWORD, - "user": POSTGRES_USER, - "cp_min": 1, - "cp_max": 5, - }, - } - else: - database_config = { - "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, - } - - if "db_txn_limit" in kwargs: - database_config["txn_limit"] = kwargs["db_txn_limit"] - - database = DatabaseConnectionConfig("master", database_config) - config.database.databases = [database] - - db_engine = create_engine(database.config) - - # Create the database before we actually try and connect to it, based off - # the template database we generate in setupdb() - if isinstance(db_engine, PostgresEngine): - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - cur.execute( - "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) - ) - cur.close() - db_conn.close() - - hs = homeserver_to_use( - name, - config=config, - version_string="Synapse/tests", - reactor=reactor, - ) - - # Install @cache_in_self attributes - for key, val in kwargs.items(): - setattr(hs, "_" + key, val) - - # Mock TLS - hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() - - hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() - - if isinstance(db_engine, PostgresEngine): - database = hs.get_datastores().databases[0] - - # We need to do cleanup on PostgreSQL - def cleanup(): - import psycopg2 - - # Close all the db pools - database._db_pool.close() - - dropped = False - - # Drop the test database - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - - # Try a few times to drop the DB. Some things may hold on to the - # database for a few more seconds due to flakiness, preventing - # us from dropping it when the test is over. If we can't drop - # it, warn and move on. - for _ in range(5): - try: - cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) - db_conn.commit() - dropped = True - except psycopg2.OperationalError as e: - warnings.warn( - "Couldn't drop old db: " + str(e), category=UserWarning - ) - time.sleep(0.5) - - cur.close() - db_conn.close() - - if not dropped: - warnings.warn("Failed to drop old DB.", category=UserWarning) - - if not LEAVE_DB: - # Register the cleanup hook - cleanup_func(cleanup) - - # bcrypt is far too slow to be doing in unit tests - # Need to let the HS build an auth handler and then mess with it - # because AuthHandler's constructor requires the HS, so we can't make one - # beforehand and pass it in to the HS's constructor (chicken / egg) - async def hash(p): - return hashlib.md5(p.encode("utf8")).hexdigest() - - hs.get_auth_handler().hash = hash - - async def validate_hash(p, h): - return hashlib.md5(p.encode("utf8")).hexdigest() == h - - hs.get_auth_handler().validate_hash = validate_hash - - # Make the threadpool and database transactions synchronous for testing. - make_test_homeserver_synchronous(hs) - - return hs |