diff --git a/tests/unittest.py b/tests/unittest.py
index 295573bc46..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@
import gc
import hashlib
import hmac
+import inspect
import logging
import time
@@ -25,7 +26,7 @@ from mock import Mock
from canonicaljson import json
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -401,10 +402,12 @@ class HomeserverTestCase(TestCase):
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
- # Run the database background updates.
- if hasattr(stor, "do_next_background_update"):
- while not self.get_success(stor.has_completed_background_updates()):
- self.get_success(stor.do_next_background_update(1))
+ # Run the database background updates, when running against "master".
+ if hs.__class__.__name__ == "TestHomeServer":
+ while not self.get_success(
+ stor.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(stor.db.updates.do_next_background_update(1))
return hs
@@ -415,6 +418,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@@ -424,6 +429,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
@@ -544,7 +551,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore().simple_insert(
+ self.hs.get_datastore().db.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
|