summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/server.py55
1 files changed, 28 insertions, 27 deletions
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     Set up a synchronous test server, driven by the reactor used by
     the homeserver.
     """
-    d = _sth(cleanup_func, *args, **kwargs).result
+    server = _sth(cleanup_func, *args, **kwargs)
 
-    if isinstance(d, Failure):
-        d.raiseException()
+    database = server.config.database.get_single_database()
 
     # Make the thread pool synchronous.
-    clock = d.get_clock()
-    pool = d.get_db_pool()
-
-    def runWithConnection(func, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runWithConnection,
-            func,
-            *args,
-            **kwargs
-        )
-
-    def runInteraction(interaction, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runInteraction,
-            interaction,
-            *args,
-            **kwargs
-        )
+    clock = server.get_clock()
+
+    for database in server.get_datastores().databases:
+        pool = database._db_pool
+
+        def runWithConnection(func, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runWithConnection,
+                func,
+                *args,
+                **kwargs
+            )
+
+        def runInteraction(interaction, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runInteraction,
+                interaction,
+                *args,
+                **kwargs
+            )
 
-    if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
-    return d
+
+    return server
 
 
 def get_clock():