summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_devices.py80
-rw-r--r--tests/storage/test_directory.py44
-rw-r--r--tests/storage/test_end_to_end_keys.py59
-rw-r--r--tests/storage/test_event_push_actions.py133
-rw-r--r--tests/storage/test_profile.py35
-rw-r--r--tests/storage/test_redaction.py12
-rw-r--r--tests/storage/test_registration.py108
-rw-r--r--tests/storage/test_room.py61
-rw-r--r--tests/storage/test_state.py145
-rw-r--r--tests/storage/test_user_directory.py86
10 files changed, 264 insertions, 499 deletions
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,32 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import synapse.api.errors
 
-import tests.unittest
-import tests.utils
-
-
-class DeviceStoreTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
+from tests.unittest import HomeserverTestCase
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
+class DeviceStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res,
         )
 
-    @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
+        res = self.get_success(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {
@@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
             res["device2"],
         )
 
-    @defer.inlineCallbacks
     def test_count_devices_by_users(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device1", "display_name 1")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device2", "display_name 2")
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id2", "device3", "display_name 3")
         )
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+        res = self.get_success(self.store.count_devices_by_users())
         self.assertEqual(0, res)
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+        res = self.get_success(self.store.count_devices_by_users(["unknown"]))
         self.assertEqual(0, res)
 
-        res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+        res = self.get_success(self.store.count_devices_by_users(["user_id"]))
         self.assertEqual(2, res)
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.count_devices_by_users(["user_id", "user_id2"])
         )
         self.assertEqual(3, res)
 
-    @defer.inlineCallbacks
     def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield defer.ensureDeferred(
+        now_stream_id, device_updates = self.get_success(
             self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )
 
@@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
-    @defer.inlineCallbacks
     def test_update_device(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_device("user_id", "device_id", "display_name 1")
         )
 
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        self.get_success(self.store.update_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.update_device(
                 "user_id", "device_id", new_display_name="display_name 2"
             )
         )
 
         # check it worked
-        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
+        res = self.get_success(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])
 
-    @defer.inlineCallbacks
     def test_update_unknown_device(self):
-        with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield defer.ensureDeferred(
-                self.store.update_device(
-                    "user_id", "unknown_device_id", new_display_name="display_name 2"
-                )
-            )
-        self.assertEqual(404, cm.exception.code)
+        exc = self.get_failure(
+            self.store.update_device(
+                "user_id", "unknown_device_id", new_display_name="display_name 2"
+            ),
+            synapse.api.errors.StoreError,
+        )
+        self.assertEqual(404, exc.value.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,28 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.types import RoomAlias, RoomID
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
 
 
-class DirectoryStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-
+class DirectoryStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
 
-    @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
@@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertEquals(
             ["#my-room:test"],
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_aliases_for_room(self.room.to_string())
-                )
-            ),
+            (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
@@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_association_from_room_alias(self.alias)
-                )
-            ),
+            (self.get_success(self.store.get_association_from_room_alias(self.alias))),
         )
 
-    @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.create_room_alias_association(
                 room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
             )
         )
 
-        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
+        room_id = self.get_success(self.store.delete_room_alias(self.alias))
         self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_association_from_room_alias(self.alias)
-                )
-            )
+            (self.get_success(self.store.get_association_from_room_alias(self.alias)))
         )
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,30 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from tests.unittest import HomeserverTestCase
 
-import tests.unittest
-import tests.utils
 
-
-class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EndToEndKeyStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_key_without_device_name(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+        self.get_success(self.store.store_device("user", "device", None))
 
-        yield defer.ensureDeferred(
-            self.store.set_e2e_device_keys("user", "device", now, json)
-        )
+        self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
         )
         self.assertIn("user", res)
@@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         dev = res["user"]["device"]
         self.assertDictContainsSubset(json, dev)
 
-    @defer.inlineCallbacks
     def test_reupload_key(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
+        self.get_success(self.store.store_device("user", "device", None))
 
-        changed = yield defer.ensureDeferred(
+        changed = self.get_success(
             self.store.set_e2e_device_keys("user", "device", now, json)
         )
         self.assertTrue(changed)
 
         # If we try to upload the same key then we should be told nothing
         # changed
-        changed = yield defer.ensureDeferred(
+        changed = self.get_success(
             self.store.set_e2e_device_keys("user", "device", now, json)
         )
         self.assertFalse(changed)
 
-    @defer.inlineCallbacks
     def test_get_key_with_device_name(self):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield defer.ensureDeferred(
-            self.store.set_e2e_device_keys("user", "device", now, json)
-        )
-        yield defer.ensureDeferred(
-            self.store.store_device("user", "device", "display_name")
-        )
+        self.get_success(self.store.set_e2e_device_keys("user", "device", now, json))
+        self.get_success(self.store.store_device("user", "device", "display_name"))
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
         )
         self.assertIn("user", res)
@@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
             {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
         )
 
-    @defer.inlineCallbacks
     def test_multiple_devices(self):
         now = 1470174257070
 
-        yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
-        yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
-        yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
-        yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
+        self.get_success(self.store.store_device("user1", "device1", None))
+        self.get_success(self.store.store_device("user1", "device2", None))
+        self.get_success(self.store.store_device("user2", "device1", None))
+        self.get_success(self.store.store_device("user2", "device2", None))
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_device_keys_for_cs_api(
                 (("user1", "device1"), ("user2", "device2"))
             )
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..239f7c9faf 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,10 +15,7 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
 
 USER_ID = "@user:example.com"
 
@@ -30,37 +27,31 @@ HIGHLIGHT = [
 ]
 
 
-class EventPushActionsStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventPushActionsStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.persist_events_store = hs.get_datastores().persist_events
 
-    @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.get_unread_push_actions_for_user_in_range_for_http(
                 USER_ID, 0, 1000, 20
             )
         )
 
-    @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_email(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.get_unread_push_actions_for_user_in_range_for_email(
                 USER_ID, 0, 1000, 20
             )
         )
 
-    @defer.inlineCallbacks
     def test_count_aggregation(self):
         room_id = "!foo:example.com"
         user_id = "@user1235:example.com"
 
-        @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield defer.ensureDeferred(
+            counts = self.get_success(
                 self.store.db_pool.runInteraction(
                     "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
                 )
@@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 },
             )
 
-        @defer.inlineCallbacks
         def _inject_actions(stream, action):
             event = Mock()
             event.room_id = room_id
@@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            yield defer.ensureDeferred(
+            self.get_success(
                 self.store.add_push_actions_to_staging(
                     event.event_id,
                     {user_id: action},
                     False,
                 )
             )
-            yield defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.runInteraction(
                     "",
                     self.persist_events_store._set_push_actions_for_event_and_users_txn,
@@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         def _rotate(stream):
-            return defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.runInteraction(
                     "", self.store._rotate_notifs_before_txn, stream
                 )
             )
 
         def _mark_read(stream, depth):
-            return defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.runInteraction(
                     "",
                     self.store._remove_old_push_actions_before_txn,
@@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
                 )
             )
 
-        yield _assert_counts(0, 0)
-        yield _inject_actions(1, PlAIN_NOTIF)
-        yield _assert_counts(1, 0)
-        yield _rotate(2)
-        yield _assert_counts(1, 0)
+        _assert_counts(0, 0)
+        _inject_actions(1, PlAIN_NOTIF)
+        _assert_counts(1, 0)
+        _rotate(2)
+        _assert_counts(1, 0)
 
-        yield _inject_actions(3, PlAIN_NOTIF)
-        yield _assert_counts(2, 0)
-        yield _rotate(4)
-        yield _assert_counts(2, 0)
+        _inject_actions(3, PlAIN_NOTIF)
+        _assert_counts(2, 0)
+        _rotate(4)
+        _assert_counts(2, 0)
 
-        yield _inject_actions(5, PlAIN_NOTIF)
-        yield _mark_read(3, 3)
-        yield _assert_counts(1, 0)
+        _inject_actions(5, PlAIN_NOTIF)
+        _mark_read(3, 3)
+        _assert_counts(1, 0)
 
-        yield _mark_read(5, 5)
-        yield _assert_counts(0, 0)
+        _mark_read(5, 5)
+        _assert_counts(0, 0)
 
-        yield _inject_actions(6, PlAIN_NOTIF)
-        yield _rotate(7)
+        _inject_actions(6, PlAIN_NOTIF)
+        _rotate(7)
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.db_pool.simple_delete(
                 table="event_push_actions", keyvalues={"1": 1}, desc=""
             )
         )
 
-        yield _assert_counts(1, 0)
+        _assert_counts(1, 0)
 
-        yield _mark_read(7, 7)
-        yield _assert_counts(0, 0)
+        _mark_read(7, 7)
+        _assert_counts(0, 0)
 
-        yield _inject_actions(8, HIGHLIGHT)
-        yield _assert_counts(1, 1)
-        yield _rotate(9)
-        yield _assert_counts(1, 1)
-        yield _rotate(10)
-        yield _assert_counts(1, 1)
+        _inject_actions(8, HIGHLIGHT)
+        _assert_counts(1, 1)
+        _rotate(9)
+        _assert_counts(1, 1)
+        _rotate(10)
+        _assert_counts(1, 1)
 
-    @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return defer.ensureDeferred(
+            self.get_success(
                 self.store.db_pool.simple_insert(
                     "events",
                     {
@@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         # start with the base case where there are no events in the table
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(11)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
         self.assertEqual(r, 0)
 
         # now with one event
-        yield add_event(2, 10)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(9)
-        )
+        add_event(2, 10)
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(9))
         self.assertEqual(r, 2)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(10)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(10))
         self.assertEqual(r, 2)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(11)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(11))
         self.assertEqual(r, 3)
 
         # add a bunch of dummy events to the events table
@@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             (10, 130),
             (20, 140),
         ):
-            yield add_event(stream_ordering, ts)
+            add_event(stream_ordering, ts)
 
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(110)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(110))
         self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
 
         # 4 and 5 are both after 120: we want 4 rather than 5
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(120)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(120))
         self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
 
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(129)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(129))
         self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
 
         # check we can get the last event
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(140)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(140))
         self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
 
         # off the end
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(160)
-        )
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(160))
         self.assertEqual(r, 21)
 
         # check we can find an event at ordering zero
-        yield add_event(0, 5)
-        r = yield defer.ensureDeferred(
-            self.store.find_first_stream_ordering_after_ts(1)
-        )
+        add_event(0, 5)
+        r = self.get_success(self.store.find_first_stream_ordering_after_ts(1))
         self.assertEqual(r, 0)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index ea63bd56b4..d18ceb41a9 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,59 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.types import UserID
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
-
 
-class ProfileStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
 
+class ProfileStoreTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
-    @defer.inlineCallbacks
     def test_displayname(self):
-        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+        self.get_success(self.store.create_profile(self.u_frank.localpart))
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
         )
 
         self.assertEquals(
             "Frank",
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.u_frank.localpart)
                 )
             ),
         )
 
         # test set to None
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.u_frank.localpart, None)
         )
 
         self.assertIsNone(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.u_frank.localpart)
                 )
             )
         )
 
-    @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
+        self.get_success(self.store.create_profile(self.u_frank.localpart))
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.u_frank.localpart, "http://my.site/here"
             )
@@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase):
         self.assertEquals(
             "http://my.site/here",
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_avatar_url(self.u_frank.localpart)
                 )
             ),
         )
 
         # test set to None
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(self.u_frank.localpart, None)
         )
 
         self.assertIsNone(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_avatar_url(self.u_frank.localpart)
                 )
             )
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2622207639 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,8 +15,6 @@
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomID, UserID
@@ -230,10 +227,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 self._base_builder = base_builder
                 self._event_id = event_id
 
-            @defer.inlineCallbacks
-            def build(self, prev_event_ids, auth_event_ids):
-                built_event = yield defer.ensureDeferred(
-                    self._base_builder.build(prev_event_ids, auth_event_ids)
+            async def build(self, prev_event_ids, auth_event_ids):
+                built_event = await self._base_builder.build(
+                    prev_event_ids, auth_event_ids
                 )
 
                 built_event._event_id = self._event_id
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,21 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.api.constants import UserTypes
 from synapse.api.errors import ThreepidValidationError
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
 
-class RegistrationStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
 
+class RegistrationStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.user_id = "@my-user:test"
@@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase):
         self.pwhash = "{xx1}123456789"
         self.device_id = "akgjhdjklgshg"
 
-    @defer.inlineCallbacks
     def test_register(self):
-        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEquals(
             {
@@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "consent_version": None,
                 "consent_server_notice_sent": None,
                 "appservice_id": None,
-                "creation_ts": 1000,
+                "creation_ts": 0,
                 "user_type": None,
                 "deactivated": 0,
                 "shadow_banned": 0,
             },
-            (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
+            (self.get_success(self.store.get_user_by_id(self.user_id))),
         )
 
-    @defer.inlineCallbacks
     def test_add_tokens(self):
-        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
-        yield defer.ensureDeferred(
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+        self.get_success(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
             )
         )
 
-        result = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[1])
-        )
+        result = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
 
         self.assertEqual(result.user_id, self.user_id)
         self.assertEqual(result.device_id, self.device_id)
         self.assertIsNotNone(result.token_id)
 
-    @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
         # add some tokens
-        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
-        yield defer.ensureDeferred(
+        self.get_success(self.store.register_user(self.user_id, self.pwhash))
+        self.get_success(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
             )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.add_access_token_to_user(
                 self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
             )
         )
 
         # now delete some
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
         )
 
         # check they were deleted
-        user = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[1])
-        )
+        user = self.get_success(self.store.get_user_by_access_token(self.tokens[1]))
         self.assertIsNone(user, "access token was not deleted by device_id")
 
         # check the one not associated with the device was not deleted
-        user = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[0])
-        )
+        user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
         self.assertEqual(self.user_id, user.user_id)
 
         # now delete the rest
-        yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
+        self.get_success(self.store.user_delete_access_tokens(self.user_id))
 
-        user = yield defer.ensureDeferred(
-            self.store.get_user_by_access_token(self.tokens[0])
-        )
+        user = self.get_success(self.store.get_user_by_access_token(self.tokens[0]))
         self.assertIsNone(user, "access token was not deleted without device_id")
 
-    @defer.inlineCallbacks
     def test_is_support_user(self):
         TEST_USER = "@test:test"
         SUPPORT_USER = "@support:test"
 
-        res = yield defer.ensureDeferred(self.store.is_support_user(None))
+        res = self.get_success(self.store.is_support_user(None))
         self.assertFalse(res)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.register_user(user_id=TEST_USER, password_hash=None)
         )
-        res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
+        res = self.get_success(self.store.is_support_user(TEST_USER))
         self.assertFalse(res)
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.register_user(
                 user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
-        res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
+        res = self.get_success(self.store.is_support_user(SUPPORT_USER))
         self.assertTrue(res)
 
-    @defer.inlineCallbacks
     def test_3pid_inhibit_invalid_validation_session_error(self):
         """Tests that enabling the configuration option to inhibit 3PID errors on
         /requestToken also inhibits validation errors caused by an unknown session ID.
@@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         # Check that, with the config setting set to false (the default value), a
         # validation error is caused by the unknown session ID.
-        try:
-            yield defer.ensureDeferred(
-                self.store.validate_threepid_session(
-                    "fake_sid",
-                    "fake_client_secret",
-                    "fake_token",
-                    0,
-                )
-            )
-        except ThreepidValidationError as e:
-            self.assertEquals(e.msg, "Unknown session_id", e)
+        e = self.get_failure(
+            self.store.validate_threepid_session(
+                "fake_sid",
+                "fake_client_secret",
+                "fake_token",
+                0,
+            ),
+            ThreepidValidationError,
+        )
+        self.assertEquals(e.value.msg, "Unknown session_id", e)
 
         # Set the config setting to true.
         self.store._ignore_unknown_session_error = True
 
         # Check that now the validation error is caused by the token not matching.
-        try:
-            yield defer.ensureDeferred(
-                self.store.validate_threepid_session(
-                    "fake_sid",
-                    "fake_client_secret",
-                    "fake_token",
-                    0,
-                )
-            )
-        except ThreepidValidationError as e:
-            self.assertEquals(e.msg, "Validation token not found or has expired", e)
+        e = self.get_failure(
+            self.store.validate_threepid_session(
+                "fake_sid",
+                "fake_client_secret",
+                "fake_token",
+                0,
+            ),
+            ThreepidValidationError,
+        )
+        self.assertEquals(e.value.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,22 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomAlias, RoomID, UserID
 
-from tests import unittest
-from tests.utils import setup_test_homeserver
-
+from tests.unittest import HomeserverTestCase
 
-class RoomStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
 
+class RoomStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         # We can't test RoomStore on its own without the DirectoryStore, for
         # management of the 'room_aliases' table
         self.store = hs.get_datastore()
@@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase):
         self.alias = RoomAlias.from_string("#a-room-name:test")
         self.u_creator = UserID.from_string("@creator:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id=self.u_creator.to_string(),
@@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def test_get_room(self):
         self.assertDictContainsSubset(
             {
@@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
             },
-            (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+            (self.get_success(self.store.get_room(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_get_room_unknown_room(self):
-        self.assertIsNone(
-            (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
-        )
+        self.assertIsNone((self.get_success(self.store.get_room("!uknown:test"))))
 
-    @defer.inlineCallbacks
     def test_get_room_with_stats(self):
         self.assertDictContainsSubset(
             {
@@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "public": True,
             },
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_room_with_stats(self.room.to_string())
-                )
-            ),
+            (self.get_success(self.store.get_room_with_stats(self.room.to_string()))),
         )
 
-    @defer.inlineCallbacks
     def test_get_room_with_stats_unknown_room(self):
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_room_with_stats("!uknown:test")
-                )
-            ),
+            (self.get_success(self.store.get_room_with_stats("!uknown:test"))),
         )
 
 
-class RoomEventsStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = setup_test_homeserver(self.addCleanup)
-
+class RoomEventsStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastore()
@@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
         self.room = RoomID.from_string("!abcde:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id="@creator:text",
@@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.storage.persistence.persist_event(
                 self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
             )
         )
 
-    @defer.inlineCallbacks
     def STALE_test_room_name(self):
         name = "A-Room-Name"
 
-        yield self.inject_room_event(
+        self.inject_room_event(
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.store.get_current_state(room_id=self.room.to_string())
         )
 
@@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             state[0],
         )
 
-    @defer.inlineCallbacks
     def STALE_test_room_topic(self):
         topic = "A place for things"
 
-        yield self.inject_room_event(
+        self.inject_room_event(
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.store.get_current_state(room_id=self.room.to_string())
         )
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,24 +15,18 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
-import tests.unittest
-import tests.utils
+from tests.unittest import HomeserverTestCase
 
 logger = logging.getLogger(__name__)
 
 
-class StateStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
-
+class StateStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
         self.state_datastore = self.storage.state.stores.state
@@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.store_room(
                 self.room.to_string(),
                 room_creator_user_id="@creator:text",
@@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
             )
         )
 
-    @defer.inlineCallbacks
     def inject_state_event(self, room, sender, typ, state_key, content):
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
@@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield defer.ensureDeferred(
+        event, context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield defer.ensureDeferred(
-            self.storage.persistence.persist_event(event, context)
-        )
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.assertEqual(s1[t].event_id, s2[t].event_id)
         self.assertEqual(len(s1), len(s2))
 
-    @defer.inlineCallbacks
     def test_get_state_groups_ids(self):
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield defer.ensureDeferred(
+        state_group_map = self.get_success(
             self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
@@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
         )
 
-    @defer.inlineCallbacks
     def test_get_state_groups(self):
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield defer.ensureDeferred(
+        state_group_map = self.get_success(
             self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
@@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
-    @defer.inlineCallbacks
     def test_get_state_for_event(self):
 
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
-        e1 = yield self.inject_state_event(
-            self.room, self.u_alice, EventTypes.Create, "", {}
-        )
-        e2 = yield self.inject_state_event(
+        e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
+        e2 = self.inject_state_event(
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
-        e3 = yield self.inject_state_event(
+        e3 = self.inject_state_event(
             self.room,
             self.u_alice,
             EventTypes.Member,
             self.u_alice.to_string(),
             {"membership": Membership.JOIN},
         )
-        e4 = yield self.inject_state_event(
+        e4 = self.inject_state_event(
             self.room,
             self.u_bob,
             EventTypes.Member,
             self.u_bob.to_string(),
             {"membership": Membership.JOIN},
         )
-        e5 = yield self.inject_state_event(
+        e5 = self.inject_state_event(
             self.room,
             self.u_bob,
             EventTypes.Member,
@@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield defer.ensureDeferred(
-            self.storage.state.get_state_for_event(e5.event_id)
-        )
+        state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
 
         self.assertIsNotNone(e4)
 
@@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
             )
@@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
             )
@@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
             )
@@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
@@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield defer.ensureDeferred(
+        state = self.get_success(
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
@@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield defer.ensureDeferred(
+        group_ids = self.get_success(
             self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
@@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
@@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (
-            state_dict,
-            is_all,
-        ) = yield self.state_datastore._get_state_for_group_using_cache(
+        (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
-from tests import unittest
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase, override_config
 
 ALICE = "@alice:a"
 BOB = "@bob:b"
@@ -25,73 +22,52 @@ BOBBY = "@bobby:a"
 BELA = "@somenickname:a"
 
 
-class UserDirectoryStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = self.hs.get_datastore()
+class UserDirectoryStoreTestCase(HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(ALICE, "alice", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(BOB, "bob", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.update_profile_in_user_dir(BELA, "Bela", None)
-        )
-        yield defer.ensureDeferred(
-            self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
-        )
+        self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None))
+        self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None))
+        self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None))
+        self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
+        self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
 
-    @defer.inlineCallbacks
     def test_search_user_dir(self):
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
-        r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
+        r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
         self.assertFalse(r["limited"])
         self.assertEqual(1, len(r["results"]))
         self.assertDictEqual(
             r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
         )
 
-    @defer.inlineCallbacks
+    @override_config({"user_directory": {"search_all_users": True}})
     def test_search_user_dir_all_users(self):
-        self.hs.config.user_directory_search_all_users = True
-        try:
-            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
-            self.assertFalse(r["limited"])
-            self.assertEqual(2, len(r["results"]))
-            self.assertDictEqual(
-                r["results"][0],
-                {"user_id": BOB, "display_name": "bob", "avatar_url": None},
-            )
-            self.assertDictEqual(
-                r["results"][1],
-                {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
-            )
-        finally:
-            self.hs.config.user_directory_search_all_users = False
+        r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(2, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": BOB, "display_name": "bob", "avatar_url": None},
+        )
+        self.assertDictEqual(
+            r["results"][1],
+            {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
+        )
 
-    @defer.inlineCallbacks
+    @override_config({"user_directory": {"search_all_users": True}})
     def test_search_user_dir_stop_words(self):
         """Tests that a user can look up another user by searching for the start if its
         display name even if that name happens to be a common English word that would
         usually be ignored in full text searches.
         """
-        self.hs.config.user_directory_search_all_users = True
-        try:
-            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10))
-            self.assertFalse(r["limited"])
-            self.assertEqual(1, len(r["results"]))
-            self.assertDictEqual(
-                r["results"][0],
-                {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
-            )
-        finally:
-            self.hs.config.user_directory_search_all_users = False
+        r = self.get_success(self.store.search_user_dir(ALICE, "be", 10))
+        self.assertFalse(r["limited"])
+        self.assertEqual(1, len(r["results"]))
+        self.assertDictEqual(
+            r["results"][0],
+            {"user_id": BELA, "display_name": "Bela", "avatar_url": None},
+        )