summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-27 17:24:46 -0400
committerGitHub <noreply@github.com>2020-08-27 17:24:46 -0400
commite00816ad98a1165b67238f9711cb1b0e7135f25f (patch)
tree3b0135e8c57b2b9082191088237298f7b50615f7 /tests
parentConvert stats and related calls to async/await (#8192) (diff)
downloadsynapse-e00816ad98a1165b67238f9711cb1b0e7135f25f.tar.xz
Do not yield on awaitables in tests. (#8193)
Diffstat (limited to '')
-rw-r--r--tests/api/test_filtering.py36
-rw-r--r--tests/crypto/test_keyring.py4
-rw-r--r--tests/federation/test_complexity.py6
-rw-r--r--tests/handlers/test_typing.py4
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py8
-rw-r--r--tests/storage/test_appservice.py24
-rw-r--r--tests/storage/test_background_update.py9
-rw-r--r--tests/storage/test_end_to_end_keys.py32
-rw-r--r--tests/storage/test_event_push_actions.py78
-rw-r--r--tests/storage/test_main.py8
-rw-r--r--tests/storage/test_registration.py44
-rw-r--r--tests/storage/test_user_directory.py16
-rw-r--r--tests/test_state.py86
-rw-r--r--tests/test_visibility.py5
14 files changed, 229 insertions, 131 deletions
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 1fab1d6b69..d2d535d23c 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -369,8 +369,10 @@ class FilteringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_filter_presence_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
         event = MockEvent(sender="@foo:bar", type="m.profile")
         events = [event]
@@ -388,8 +390,10 @@ class FilteringTestCase(unittest.TestCase):
     def test_filter_presence_no_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
 
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart + "2", user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart + "2", user_filter=user_filter_json
+            )
         )
         event = MockEvent(
             event_id="$asdasd:localhost",
@@ -410,8 +414,10 @@ class FilteringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_filter_room_state_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
         events = [event]
@@ -428,8 +434,10 @@ class FilteringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_filter_room_state_no_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
         event = MockEvent(
             sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
     def test_add_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
-        filter_id = yield self.filtering.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.filtering.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
 
         self.assertEquals(filter_id, 0)
@@ -485,8 +495,10 @@ class FilteringTestCase(unittest.TestCase):
     def test_get_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
 
         filter = yield defer.ensureDeferred(
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 0d4b05304b..d264653e74 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -190,7 +190,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         # should fail immediately on an unsigned object
         d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
-        self.failureResultOf(d, SynapseError)
+        self.get_failure(d, SynapseError)
 
         # should succeed on a signed object
         d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
@@ -221,7 +221,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         # should fail immediately on an unsigned object
         d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
-        self.failureResultOf(d, SynapseError)
+        self.get_failure(d, SynapseError)
 
         # should fail on a signed object with a non-zero minimum_valid_until_ms,
         # as it tries to refetch the keys and fails.
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 9bd515080c..3d880c499d 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -15,8 +15,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
@@ -60,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
 
         # Artificially raise the complexity
         store = self.hs.get_datastore()
-        store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+        store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
 
         # Get the room complexity again -- make sure it's our artificial value
         request, channel = self.make_request(
@@ -160,7 +158,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         )
 
         # Artificially raise the complexity
-        self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
+        self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
             600
         )
 
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 81c1839637..7bf15c4ba9 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -155,7 +155,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
             ([], 0)
         )
-        self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
+        self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
+            None
+        )
         self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
             None
         )
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index e0e9e94fbf..de00350580 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from synapse.api.errors import Codes
 from synapse.rest.client.v2_alpha import filter
 
@@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
 
     def test_get_filter(self):
-        filter_id = self.filtering.add_user_filter(
-            user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+        filter_id = defer.ensureDeferred(
+            self.filtering.add_user_filter(
+                user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+            )
         )
         self.reactor.advance(1)
         filter_id = filter_id.result
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 17fbde284a..cb808d4de4 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -243,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def test_create_appservice_txn_first(self):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 1)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -255,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_last_txn(service.id, 9643)  # AS is falling behind
         yield self._insert_txn(service.id, 9644, events)
         yield self._insert_txn(service.id, 9645, events)
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9646)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -265,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
         yield self._set_last_txn(service.id, 9643)
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -286,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._insert_txn(self.as_list[2]["id"], 10, events)
         yield self._insert_txn(self.as_list[3]["id"], 9643, events)
 
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -298,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         txn_id = 1
 
         yield self._insert_txn(service.id, txn_id, events)
-        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        yield defer.ensureDeferred(
+            self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        )
 
         res = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
@@ -324,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         txn_id = 5
         yield self._set_last_txn(service.id, 4)
         yield self._insert_txn(service.id, txn_id, events)
-        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        yield defer.ensureDeferred(
+            self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        )
 
         res = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1a1c59256c..02aae1c13d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,5 @@
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.storage.background_updates import BackgroundUpdater
 
 from tests import unittest
@@ -38,11 +36,10 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # first step: make a bit of progress
-        @defer.inlineCallbacks
-        def update(progress, count):
-            yield self.clock.sleep((count * duration_ms) / 1000)
+        async def update(progress, count):
+            await self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield store.db_pool.runInteraction(
+            await store.db_pool.runInteraction(
                 "update_progress",
                 self.updates._background_update_progress_txn,
                 "test_update",
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index d57cdffd8b..261bf5b08b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -32,7 +32,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
 
         yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
-        yield self.store.set_e2e_device_keys("user", "device", now, json)
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
 
         res = yield defer.ensureDeferred(
             self.store.get_e2e_device_keys((("user", "device"),))
@@ -49,12 +51,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
 
         yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
-        changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+        changed = yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         self.assertTrue(changed)
 
         # If we try to upload the same key then we should be told nothing
         # changed
-        changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+        changed = yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         self.assertFalse(changed)
 
     @defer.inlineCallbacks
@@ -62,7 +68,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.set_e2e_device_keys("user", "device", now, json)
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         yield defer.ensureDeferred(
             self.store.store_device("user", "device", "display_name")
         )
@@ -86,10 +94,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
         yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
 
-        yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
-        yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
-        yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
-        yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+        )
 
         res = yield defer.ensureDeferred(
             self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 0e7427e57a..cdfd2634aa 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -60,8 +60,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.db_pool.runInteraction(
-                "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+            counts = yield defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+                )
             )
             self.assertEquals(
                 counts,
@@ -81,25 +83,31 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                     event.event_id, {user_id: action}
                 )
             )
-            yield self.store.db_pool.runInteraction(
-                "",
-                self.persist_events_store._set_push_actions_for_event_and_users_txn,
-                [(event, None)],
-                [(event, None)],
+            yield defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "",
+                    self.persist_events_store._set_push_actions_for_event_and_users_txn,
+                    [(event, None)],
+                    [(event, None)],
+                )
             )
 
         def _rotate(stream):
-            return self.store.db_pool.runInteraction(
-                "", self.store._rotate_notifs_before_txn, stream
+            return defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "", self.store._rotate_notifs_before_txn, stream
+                )
             )
 
         def _mark_read(stream, depth):
-            return self.store.db_pool.runInteraction(
-                "",
-                self.store._remove_old_push_actions_before_txn,
-                room_id,
-                user_id,
-                stream,
+            return defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "",
+                    self.store._remove_old_push_actions_before_txn,
+                    room_id,
+                    user_id,
+                    stream,
+                )
             )
 
         yield _assert_counts(0, 0)
@@ -163,16 +171,24 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         # start with the base case where there are no events in the table
-        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(11)
+        )
         self.assertEqual(r, 0)
 
         # now with one event
         yield add_event(2, 10)
-        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(9)
+        )
         self.assertEqual(r, 2)
-        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(10)
+        )
         self.assertEqual(r, 2)
-        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(11)
+        )
         self.assertEqual(r, 3)
 
         # add a bunch of dummy events to the events table
@@ -185,25 +201,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         ):
             yield add_event(stream_ordering, ts)
 
-        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(110)
+        )
         self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
 
         # 4 and 5 are both after 120: we want 4 rather than 5
-        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(120)
+        )
         self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
 
-        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(129)
+        )
         self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
 
         # check we can get the last event
-        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(140)
+        )
         self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
 
         # off the end
-        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(160)
+        )
         self.assertEqual(r, 21)
 
         # check we can find an event at ordering zero
         yield add_event(0, 5)
-        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(1)
+        )
         self.assertEqual(r, 0)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index 954338a592..7e7f1286d9 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,14 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_users_paginate(self):
-        yield self.store.register_user(self.user.to_string(), "pass")
+        yield defer.ensureDeferred(
+            self.store.register_user(self.user.to_string(), "pass")
+        )
         yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
         yield defer.ensureDeferred(
             self.store.set_profile_displayname(self.user.localpart, self.displayname)
         )
 
-        users, total = yield self.store.get_users_paginate(
-            0, 10, name="bc", guests=False
+        users, total = yield defer.ensureDeferred(
+            self.store.get_users_paginate(0, 10, name="bc", guests=False)
         )
 
         self.assertEquals(1, total)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 70c55cd650..6b582771fe 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -37,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_register(self):
-        yield self.store.register_user(self.user_id, self.pwhash)
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEquals(
             {
@@ -58,14 +58,16 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_add_tokens(self):
-        yield self.store.register_user(self.user_id, self.pwhash)
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
         yield defer.ensureDeferred(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
             )
         )
 
-        result = yield self.store.get_user_by_access_token(self.tokens[1])
+        result = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[1])
+        )
 
         self.assertDictContainsSubset(
             {"name": self.user_id, "device_id": self.device_id}, result
@@ -76,7 +78,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
         # add some tokens
-        yield self.store.register_user(self.user_id, self.pwhash)
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
         yield defer.ensureDeferred(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
@@ -89,22 +91,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
         )
 
         # now delete some
-        yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id
+        yield defer.ensureDeferred(
+            self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
         )
 
         # check they were deleted
-        user = yield self.store.get_user_by_access_token(self.tokens[1])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[1])
+        )
         self.assertIsNone(user, "access token was not deleted by device_id")
 
         # check the one not associated with the device was not deleted
-        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[0])
+        )
         self.assertEqual(self.user_id, user["name"])
 
         # now delete the rest
-        yield self.store.user_delete_access_tokens(self.user_id)
+        yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
 
-        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[0])
+        )
         self.assertIsNone(user, "access token was not deleted without device_id")
 
     @defer.inlineCallbacks
@@ -112,16 +120,20 @@ class RegistrationStoreTestCase(unittest.TestCase):
         TEST_USER = "@test:test"
         SUPPORT_USER = "@support:test"
 
-        res = yield self.store.is_support_user(None)
+        res = yield defer.ensureDeferred(self.store.is_support_user(None))
         self.assertFalse(res)
-        yield self.store.register_user(user_id=TEST_USER, password_hash=None)
-        res = yield self.store.is_support_user(TEST_USER)
+        yield defer.ensureDeferred(
+            self.store.register_user(user_id=TEST_USER, password_hash=None)
+        )
+        res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
         self.assertFalse(res)
 
-        yield self.store.register_user(
-            user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+        yield defer.ensureDeferred(
+            self.store.register_user(
+                user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+            )
         )
-        res = yield self.store.is_support_user(SUPPORT_USER)
+        res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
         self.assertTrue(res)
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index ecfafe68a9..738e912468 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -31,10 +31,18 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
-        yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
-        yield self.store.update_profile_in_user_dir(BOB, "bob", None)
-        yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
-        yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(ALICE, "alice", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BOB, "bob", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+        )
 
     @defer.inlineCallbacks
     def test_search_user_dir(self):
diff --git a/tests/test_state.py b/tests/test_state.py
index b5c3667d2a..56ba0fecf5 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -80,16 +80,16 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups_ids(self, room_id, event_ids):
+    async def get_state_groups_ids(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
             if group:
                 groups[group] = self._group_to_state[group]
 
-        return defer.succeed(groups)
+        return groups
 
-    def store_state_group(
+    async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
     ):
         state_group = self._next_group
@@ -97,19 +97,17 @@ class StateGroupStore(object):
 
         self._group_to_state[state_group] = dict(current_state_ids)
 
-        return defer.succeed(state_group)
+        return state_group
 
-    def get_events(self, event_ids, **kwargs):
-        return defer.succeed(
-            {
-                e_id: self._event_id_to_event[e_id]
-                for e_id in event_ids
-                if e_id in self._event_id_to_event
-            }
-        )
+    async def get_events(self, event_ids, **kwargs):
+        return {
+            e_id: self._event_id_to_event[e_id]
+            for e_id in event_ids
+            if e_id in self._event_id_to_event
+        }
 
-    def get_state_group_delta(self, name):
-        return defer.succeed((None, None))
+    async def get_state_group_delta(self, name):
+        return (None, None)
 
     def register_events(self, events):
         for e in events:
@@ -121,8 +119,8 @@ class StateGroupStore(object):
     def register_event_id_state_group(self, event_id, state_group):
         self._event_to_state_group[event_id] = state_group
 
-    def get_room_version_id(self, room_id):
-        return defer.succeed(RoomVersions.V1.identifier)
+    async def get_room_version_id(self, room_id):
+        return RoomVersions.V1.identifier
 
 
 class DictObj(dict):
@@ -476,12 +474,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = yield self.store.store_state_group(
-            prev_event_id,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state},
+        group_name = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
@@ -508,12 +508,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = yield self.store.store_state_group(
-            prev_event_id,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state},
+        group_name = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
@@ -691,21 +693,25 @@ class StateTestCase(unittest.TestCase):
     def _get_context(
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
-        sg1 = yield self.store.store_state_group(
-            prev_event_id_1,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        sg1 = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id_1,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state_1},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        sg2 = yield self.store.store_state_group(
-            prev_event_id_2,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        sg2 = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id_2,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state_2},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 531a9b9118..4a4483ba12 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -37,7 +37,6 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.hs = yield setup_test_homeserver(self.addCleanup)
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
-        self.store = self.hs.get_datastore()
         self.storage = self.hs.get_storage()
 
         yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -99,7 +98,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         events_to_filter.append(evt)
 
         # the erasey user gets erased
-        yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+        yield defer.ensureDeferred(
+            self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+        )
 
         # ... and the filtering happens.
         filtered = yield defer.ensureDeferred(