summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py112
1 files changed, 59 insertions, 53 deletions
diff --git a/tests/server.py b/tests/server.py
index f0cc4206b0..4aaa91e956 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,7 +47,7 @@ from typing import (
     Union,
     cast,
 )
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import attr
 from incremental import Version
@@ -55,6 +55,7 @@ from typing_extensions import ParamSpec
 from zope.interface import implementer
 
 import twisted
+from twisted.enterprise import adbapi
 from twisted.internet import address, tcp, threads, udp
 from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@@ -94,8 +95,8 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
 )
 from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import create_engine
+from synapse.storage.database import LoggingDatabaseConnection, make_pool
+from synapse.storage.engines import BaseDatabaseEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
@@ -670,6 +671,53 @@ def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
         )
 
 
+def make_fake_db_pool(
+    reactor: ISynapseReactor,
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+) -> adbapi.ConnectionPool:
+    """Wrapper for `make_pool` which builds a pool which runs db queries synchronously.
+
+    For more deterministic testing, we don't use a regular db connection pool: instead
+    we run all db queries synchronously on the test reactor's main thread. This function
+    is a drop-in replacement for the normal `make_pool` which builds such a connection
+    pool.
+    """
+    pool = make_pool(reactor, db_config, engine)
+
+    def runWithConnection(
+        func: Callable[..., R], *args: Any, **kwargs: Any
+    ) -> Awaitable[R]:
+        return threads.deferToThreadPool(
+            pool._reactor,
+            pool.threadpool,
+            pool._runWithConnection,
+            func,
+            *args,
+            **kwargs,
+        )
+
+    def runInteraction(
+        desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
+    ) -> Awaitable[R]:
+        return threads.deferToThreadPool(
+            pool._reactor,
+            pool.threadpool,
+            pool._runInteraction,
+            desc,
+            func,
+            *args,
+            **kwargs,
+        )
+
+    pool.runWithConnection = runWithConnection  # type: ignore[method-assign]
+    pool.runInteraction = runInteraction  # type: ignore[assignment]
+    # Replace the thread pool with a threadless 'thread' pool
+    pool.threadpool = ThreadPool(reactor)
+    pool.running = True
+    return pool
+
+
 class ThreadPool:
     """
     Threadless thread pool.
@@ -706,52 +754,6 @@ class ThreadPool:
         return d
 
 
-def _make_test_homeserver_synchronous(server: HomeServer) -> None:
-    """
-    Make the given test homeserver's database interactions synchronous.
-    """
-
-    clock = server.get_clock()
-
-    for database in server.get_datastores().databases:
-        pool = database._db_pool
-
-        def runWithConnection(
-            func: Callable[..., R], *args: Any, **kwargs: Any
-        ) -> Awaitable[R]:
-            return threads.deferToThreadPool(
-                pool._reactor,
-                pool.threadpool,
-                pool._runWithConnection,
-                func,
-                *args,
-                **kwargs,
-            )
-
-        def runInteraction(
-            desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
-        ) -> Awaitable[R]:
-            return threads.deferToThreadPool(
-                pool._reactor,
-                pool.threadpool,
-                pool._runInteraction,
-                desc,
-                func,
-                *args,
-                **kwargs,
-            )
-
-        pool.runWithConnection = runWithConnection  # type: ignore[method-assign]
-        pool.runInteraction = runInteraction  # type: ignore[assignment]
-        # Replace the thread pool with a threadless 'thread' pool
-        pool.threadpool = ThreadPool(clock._reactor)
-        pool.running = True
-
-    # We've just changed the Databases to run DB transactions on the same
-    # thread, so we need to disable the dedicated thread behaviour.
-    server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
-
-
 def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
     clock = ThreadedMemoryReactorClock()
     hs_clock = Clock(clock)
@@ -1067,7 +1069,14 @@ def setup_test_homeserver(
     # Mock TLS
     hs.tls_server_context_factory = Mock()
 
-    hs.setup()
+    # Patch `make_pool` before initialising the database, to make database transactions
+    # synchronous for testing.
+    with patch("synapse.storage.database.make_pool", side_effect=make_fake_db_pool):
+        hs.setup()
+
+    # Since we've changed the databases to run DB transactions on the same
+    # thread, we need to stop the event fetcher hogging that one thread.
+    hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
 
     if USE_POSTGRES_FOR_TESTS:
         database_pool = hs.get_datastores().databases[0]
@@ -1137,9 +1146,6 @@ def setup_test_homeserver(
 
     hs.get_auth_handler().validate_hash = validate_hash  # type: ignore[assignment]
 
-    # Make the threadpool and database transactions synchronous for testing.
-    _make_test_homeserver_synchronous(hs)
-
     # Load any configured modules into the homeserver
     module_api = hs.get_module_api()
     for module, module_config in hs.config.modules.loaded_modules: