diff options
Diffstat (limited to '')
-rw-r--r-- | tests/server.py | 112 |
1 files changed, 59 insertions, 53 deletions
diff --git a/tests/server.py b/tests/server.py index f0cc4206b0..4aaa91e956 100644 --- a/tests/server.py +++ b/tests/server.py @@ -47,7 +47,7 @@ from typing import ( Union, cast, ) -from unittest.mock import Mock +from unittest.mock import Mock, patch import attr from incremental import Version @@ -55,6 +55,7 @@ from typing_extensions import ParamSpec from zope.interface import implementer import twisted +from twisted.enterprise import adbapi from twisted.internet import address, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed @@ -94,8 +95,8 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( ) from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import create_engine +from synapse.storage.database import LoggingDatabaseConnection, make_pool +from synapse.storage.engines import BaseDatabaseEngine, create_engine from synapse.storage.prepare_database import prepare_database from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock @@ -670,6 +671,53 @@ def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: ) +def make_fake_db_pool( + reactor: ISynapseReactor, + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, +) -> adbapi.ConnectionPool: + """Wrapper for `make_pool` which builds a pool which runs db queries synchronously. + + For more deterministic testing, we don't use a regular db connection pool: instead + we run all db queries synchronously on the test reactor's main thread. This function + is a drop-in replacement for the normal `make_pool` which builds such a connection + pool. + """ + pool = make_pool(reactor, db_config, engine) + + def runWithConnection( + func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs, + ) + + def runInteraction( + desc: str, func: Callable[..., R], *args: Any, **kwargs: Any + ) -> Awaitable[R]: + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + desc, + func, + *args, + **kwargs, + ) + + pool.runWithConnection = runWithConnection # type: ignore[method-assign] + pool.runInteraction = runInteraction # type: ignore[assignment] + # Replace the thread pool with a threadless 'thread' pool + pool.threadpool = ThreadPool(reactor) + pool.running = True + return pool + + class ThreadPool: """ Threadless thread pool. @@ -706,52 +754,6 @@ class ThreadPool: return d -def _make_test_homeserver_synchronous(server: HomeServer) -> None: - """ - Make the given test homeserver's database interactions synchronous. - """ - - clock = server.get_clock() - - for database in server.get_datastores().databases: - pool = database._db_pool - - def runWithConnection( - func: Callable[..., R], *args: Any, **kwargs: Any - ) -> Awaitable[R]: - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs, - ) - - def runInteraction( - desc: str, func: Callable[..., R], *args: Any, **kwargs: Any - ) -> Awaitable[R]: - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - desc, - func, - *args, - **kwargs, - ) - - pool.runWithConnection = runWithConnection # type: ignore[method-assign] - pool.runInteraction = runInteraction # type: ignore[assignment] - # Replace the thread pool with a threadless 'thread' pool - pool.threadpool = ThreadPool(clock._reactor) - pool.running = True - - # We've just changed the Databases to run DB transactions on the same - # thread, so we need to disable the dedicated thread behaviour. - server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False - - def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) @@ -1067,7 +1069,14 @@ def setup_test_homeserver( # Mock TLS hs.tls_server_context_factory = Mock() - hs.setup() + # Patch `make_pool` before initialising the database, to make database transactions + # synchronous for testing. + with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool): + hs.setup() + + # Since we've changed the databases to run DB transactions on the same + # thread, we need to stop the event fetcher hogging that one thread. + hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False if USE_POSTGRES_FOR_TESTS: database_pool = hs.get_datastores().databases[0] @@ -1137,9 +1146,6 @@ def setup_test_homeserver( hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment] - # Make the threadpool and database transactions synchronous for testing. - _make_test_homeserver_synchronous(hs) - # Load any configured modules into the homeserver module_api = hs.get_module_api() for module, module_config in hs.config.modules.loaded_modules: |