diff --git a/changelog.d/17017.misc b/changelog.d/17017.misc
new file mode 100644
index 0000000000..c8af23d67a
--- /dev/null
+++ b/changelog.d/17017.misc
@@ -0,0 +1 @@
+Patch the db conn pool sooner in tests.
diff --git a/tests/server.py b/tests/server.py
index f0cc4206b0..4aaa91e956 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,7 +47,7 @@ from typing import (
Union,
cast,
)
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
import attr
from incremental import Version
@@ -55,6 +55,7 @@ from typing_extensions import ParamSpec
from zope.interface import implementer
import twisted
+from twisted.enterprise import adbapi
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@@ -94,8 +95,8 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
)
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import create_engine
+from synapse.storage.database import LoggingDatabaseConnection, make_pool
+from synapse.storage.engines import BaseDatabaseEngine, create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
@@ -670,6 +671,53 @@ def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
)
+def make_fake_db_pool(
+ reactor: ISynapseReactor,
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
+) -> adbapi.ConnectionPool:
+ """Wrapper for `make_pool` which builds a pool which runs db queries synchronously.
+
+ For more deterministic testing, we don't use a regular db connection pool: instead
+ we run all db queries synchronously on the test reactor's main thread. This function
+ is a drop-in replacement for the normal `make_pool` which builds such a connection
+ pool.
+ """
+ pool = make_pool(reactor, db_config, engine)
+
+ def runWithConnection(
+ func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs,
+ )
+
+ def runInteraction(
+ desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+ ) -> Awaitable[R]:
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ desc,
+ func,
+ *args,
+ **kwargs,
+ )
+
+ pool.runWithConnection = runWithConnection # type: ignore[method-assign]
+ pool.runInteraction = runInteraction # type: ignore[assignment]
+ # Replace the thread pool with a threadless 'thread' pool
+ pool.threadpool = ThreadPool(reactor)
+ pool.running = True
+ return pool
+
+
class ThreadPool:
"""
Threadless thread pool.
@@ -706,52 +754,6 @@ class ThreadPool:
return d
-def _make_test_homeserver_synchronous(server: HomeServer) -> None:
- """
- Make the given test homeserver's database interactions synchronous.
- """
-
- clock = server.get_clock()
-
- for database in server.get_datastores().databases:
- pool = database._db_pool
-
- def runWithConnection(
- func: Callable[..., R], *args: Any, **kwargs: Any
- ) -> Awaitable[R]:
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runWithConnection,
- func,
- *args,
- **kwargs,
- )
-
- def runInteraction(
- desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
- ) -> Awaitable[R]:
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runInteraction,
- desc,
- func,
- *args,
- **kwargs,
- )
-
- pool.runWithConnection = runWithConnection # type: ignore[method-assign]
- pool.runInteraction = runInteraction # type: ignore[assignment]
- # Replace the thread pool with a threadless 'thread' pool
- pool.threadpool = ThreadPool(clock._reactor)
- pool.running = True
-
- # We've just changed the Databases to run DB transactions on the same
- # thread, so we need to disable the dedicated thread behaviour.
- server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
-
-
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
@@ -1067,7 +1069,14 @@ def setup_test_homeserver(
# Mock TLS
hs.tls_server_context_factory = Mock()
- hs.setup()
+ # Patch `make_pool` before initialising the database, to make database transactions
+ # synchronous for testing.
+ with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool):
+ hs.setup()
+
+ # Since we've changed the databases to run DB transactions on the same
+ # thread, we need to stop the event fetcher hogging that one thread.
+ hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
if USE_POSTGRES_FOR_TESTS:
database_pool = hs.get_datastores().databases[0]
@@ -1137,9 +1146,6 @@ def setup_test_homeserver(
hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
- # Make the threadpool and database transactions synchronous for testing.
- _make_test_homeserver_synchronous(hs)
-
# Load any configured modules into the homeserver
module_api = hs.get_module_api()
for module, module_config in hs.config.modules.loaded_modules:
|