summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py34
-rw-r--r--tests/storage/test_appservice.py71
-rw-r--r--tests/storage/test_background_update.py20
-rw-r--r--tests/storage/test_base.py87
-rw-r--r--tests/storage/test_cleanup_extrems.py21
-rw-r--r--tests/storage/test_client_ips.py31
-rw-r--r--tests/storage/test_devices.py52
-rw-r--r--tests/storage/test_directory.py38
-rw-r--r--tests/storage/test_end_to_end_keys.py62
-rw-r--r--tests/storage/test_event_federation.py16
-rw-r--r--tests/storage/test_event_metrics.py2
-rw-r--r--tests/storage/test_event_push_actions.py138
-rw-r--r--tests/storage/test_id_generators.py217
-rw-r--r--tests/storage/test_main.py14
-rw-r--r--tests/storage/test_monthly_active_users.py31
-rw-r--r--tests/storage/test_profile.py27
-rw-r--r--tests/storage/test_purge.py51
-rw-r--r--tests/storage/test_redaction.py12
-rw-r--r--tests/storage/test_registration.py95
-rw-r--r--tests/storage/test_room.py62
-rw-r--r--tests/storage/test_roommember.py76
-rw-r--r--tests/storage/test_state.py80
-rw-r--r--tests/storage/test_user_directory.py20
23 files changed, 856 insertions, 401 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..f5afed017c 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase):
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     @defer.inlineCallbacks
     def test_passthrough(self):
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 return key
@@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_hit(self):
         callcount = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_invalidate(self):
         callcount = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEquals(callcount[0], 2)
 
     def test_invalidate_missing(self):
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 return key
@@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_max_entries(self):
         callcount = [0]
 
-        class A(object):
+        class A:
             @cached(max_entries=10)
             def func(self, key):
                 callcount[0] += 1
@@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         d = defer.succeed(123)
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]
         callcount2 = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]
         callcount2 = [0]
 
-        class A(object):
+        class A:
             @cached(max_entries=2)
             def func(self, key):
                 callcount[0] += 1
@@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]
         callcount2 = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         self.table_name = "table_" + hs.get_secrets().token_hex(6)
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "create",
                 lambda x, *a: x.execute(*a),
                 "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "index",
                 lambda x, *a: x.execute(*a),
                 "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["hello"], ["there"]]
 
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "test",
-                self.storage.db.simple_upsert_many_txn,
+                self.storage.db_pool.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.db.simple_select_list(
+            self.storage.db_pool.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["bleb"]]
 
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "test",
-                self.storage.db.simple_upsert_many_txn,
+                self.storage.db_pool.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.db.simple_select_list(
+            self.storage.db_pool.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..cb808d4de4 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,13 +24,14 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database, make_conn
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_appservice_state_none(self):
         service = Mock(id="999")
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(None, state)
 
     @defer.inlineCallbacks
     def test_get_appservice_state_up(self):
         yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
         service = Mock(id=self.as_list[0]["id"])
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(ApplicationServiceState.UP, state)
 
     @defer.inlineCallbacks
@@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         service = Mock(id=self.as_list[1]["id"])
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(ApplicationServiceState.DOWN, state)
 
     @defer.inlineCallbacks
     def test_get_appservices_by_state_none(self):
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(0, len(services))
 
     @defer.inlineCallbacks
     def test_set_appservices_state_down(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_multiple_up(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -234,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def test_create_appservice_txn_first(self):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 1)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -246,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_last_txn(service.id, 9643)  # AS is falling behind
         yield self._insert_txn(service.id, 9644, events)
         yield self._insert_txn(service.id, 9645, events)
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9646)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -256,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
         yield self._set_last_txn(service.id, 9643)
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -277,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._insert_txn(self.as_list[2]["id"], 10, events)
         yield self._insert_txn(self.as_list[3]["id"], 9643, events)
 
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -289,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         txn_id = 1
 
         yield self._insert_txn(service.id, txn_id, events)
-        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        yield defer.ensureDeferred(
+            self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        )
 
         res = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
@@ -315,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         txn_id = 5
         yield self._set_last_txn(service.id, 4)
         yield self._insert_txn(service.id, txn_id, events)
-        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        yield defer.ensureDeferred(
+            self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        )
 
         res = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
@@ -339,7 +360,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def test_get_oldest_unsent_txn_none(self):
         service = Mock(id=self.as_list[0]["id"])
 
-        txn = yield self.store.get_oldest_unsent_txn(service)
+        txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
         self.assertEquals(None, txn)
 
     @defer.inlineCallbacks
@@ -349,14 +370,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
 
         # we aren't testing store._base stuff here, so mock this out
-        self.store.get_events_as_list = Mock(return_value=events)
+        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
 
         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)
         yield self._insert_txn(service.id, 11, other_events)
         yield self._insert_txn(service.id, 12, other_events)
 
-        txn = yield self.store.get_oldest_unsent_txn(service)
+        txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
         self.assertEquals(service, txn.service)
         self.assertEquals(10, txn.id)
         self.assertEquals(events, txn.events)
@@ -366,8 +387,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
 
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(1, len(services))
         self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +400,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
 
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(2, len(services))
         self.assertEquals(
@@ -391,7 +412,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..02aae1c13d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,5 @@
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.storage.background_updates import BackgroundUpdater
 
 from tests import unittest
@@ -9,7 +7,9 @@ from tests import unittest
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates = self.hs.get_datastore().db.updates  # type: BackgroundUpdater
+        self.updates = (
+            self.hs.get_datastore().db_pool.updates
+        )  # type: BackgroundUpdater
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
@@ -29,18 +29,17 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         store = self.hs.get_datastore()
         self.get_success(
-            store.db.simple_insert(
+            store.db_pool.simple_insert(
                 "background_updates",
                 values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
             )
         )
 
         # first step: make a bit of progress
-        @defer.inlineCallbacks
-        def update(progress, count):
-            yield self.clock.sleep((count * duration_ms) / 1000)
+        async def update(progress, count):
+            await self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield store.db.runInteraction(
+            await store.db_pool.runInteraction(
                 "update_progress",
                 self.updates._background_update_progress_txn,
                 "test_update",
@@ -65,13 +64,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         # second step: complete the update
         # we should now get run with a much bigger number of items to update
-        @defer.inlineCallbacks
-        def update(progress, count):
+        async def update(progress, count):
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
                 count, target_background_update_duration_ms / duration_ms, places=0,
             )
-            yield self.updates._end_background_update("test_update")
+            await self.updates._end_background_update("test_update")
             return count
 
         self.update_handler.side_effect = update
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 278961c331..40ba652248 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,11 +21,11 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.engines import create_engine
 
 from tests import unittest
-from tests.utils import TestHomeServer
+from tests.utils import TestHomeServer, default_config
 
 
 class SQLBaseStoreTestCase(unittest.TestCase):
@@ -49,10 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 
         self.db_pool.runWithConnection = runWithConnection
 
-        config = Mock()
-        config._disable_native_upserts = True
-        config.caches = Mock()
-        config.caches.event_cache_size = 1
+        config = default_config(name="test", parse=True)
         hs = TestHomeServer("test", config=config)
 
         sqlite_config = {"name": "sqlite3"}
@@ -60,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
 
-        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
         db._db_pool = self.db_pool
 
         self.datastore = SQLBaseStore(db, None, hs)
@@ -69,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_insert(
-            table="tablename", values={"columname": "Value"}
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename", values={"columname": "Value"}
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -81,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_insert(
-            table="tablename",
-            # Use OrderedDict() so we can assert on the SQL generated
-            values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename",
+                # Use OrderedDict() so we can assert on the SQL generated
+                values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -96,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore.db.simple_select_one_onecol(
-            table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+        value = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one_onecol(
+                table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+            )
         )
 
         self.assertEquals("Value", value)
@@ -110,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore.db.simple_select_one(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            retcols=["colA", "colB", "colC"],
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one(
+                table="tablename",
+                keyvalues={"keycol": "TheKey"},
+                retcols=["colA", "colB", "colC"],
+            )
         )
 
         self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -126,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore.db.simple_select_one(
-            table="tablename",
-            keyvalues={"keycol": "Not here"},
-            retcols=["colA"],
-            allow_none=True,
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one(
+                table="tablename",
+                keyvalues={"keycol": "Not here"},
+                retcols=["colA"],
+                allow_none=True,
+            )
         )
 
         self.assertFalse(ret)
@@ -141,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore.db.simple_select_list(
-            table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_list(
+                table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+            )
         )
 
         self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -154,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_update_one(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            updatevalues={"columnname": "New Value"},
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_one(
+                table="tablename",
+                keyvalues={"keycol": "TheKey"},
+                updatevalues={"columnname": "New Value"},
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -169,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_update_one(
-            table="tablename",
-            keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
-            updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_one(
+                table="tablename",
+                keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+                updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -184,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_delete_one(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_delete_one(
-            table="tablename", keyvalues={"keycol": "Go away"}
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_delete_one(
+                table="tablename", keyvalues={"keycol": "Go away"}
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
 
         # Create a test user and room
         self.user = UserID("alice", "test")
-        self.requester = Requester(self.user, None, False, None, None)
+        self.requester = Requester(self.user, None, False, False, None, None)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
 
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         """
         # Make sure we don't clash with in progress updates.
         self.assertTrue(
-            self.store.db.updates._all_done, "Background updates are still ongoing"
+            self.store.db_pool.updates._all_done, "Background updates are still ongoing"
         )
 
         schema_path = os.path.join(
             prepare_database.dir_path,
-            "data_stores",
+            "databases",
             "main",
             "schema",
             "delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             prepare_database.executescript(txn, schema_path)
 
         self.get_success(
-            self.store.db.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test_delete_forward_extremities", run_delta_file
             )
         )
 
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
     def test_soft_failed_extremities_handled_correctly(self):
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID.from_string(self.register_user("user1", "password"))
         self.token1 = self.login("user1", "password")
-        self.requester = Requester(self.user, None, False, None, None)
+        self.requester = Requester(self.user, None, False, False, None, None)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
         self.event_creator = homeserver.get_event_creation_handler()
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
 
         # Pump the reactor repeatedly so that the background updates have a
         # chance to run.
-        self.pump(10 * 60)
+        self.pump(20)
 
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
             "3"
         ] = 300000
+
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         # All entries within time frame
         self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
             3,
         )
         # Oldest room to expire
-        self.pump(1)
+        self.pump(1.01)
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         self.assertEqual(
             len(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..370c247e16 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -16,13 +16,12 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse.rest.admin
 from synapse.http.site import XForwardedForRequest
 from synapse.rest.client.v1 import login
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
@@ -86,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         user_id = "@user:server"
 
         self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(lots_of_users)
+            side_effect=lambda: make_awaitable(lots_of_users)
         )
         self.get_success(
             self.store.insert_client_ip(
@@ -204,10 +203,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
     def test_devices_last_seen_bg_update(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         user_id = "@user:id"
@@ -225,7 +224,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But clear the associated entry in devices table
         self.get_success(
-            self.store.db.simple_update(
+            self.store.db_pool.simple_update(
                 table="devices",
                 keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +251,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "devices_last_seen",
@@ -263,14 +262,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # We should now get the correct result again
@@ -293,10 +292,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
     def test_old_user_ips_pruned(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         user_id = "@user:id"
@@ -315,7 +314,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should see that in the DB
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +340,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should get no results.
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index c2539b353a..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -34,9 +34,11 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield self.store.store_device("user_id", "device_id", "display_name")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device_id", "display_name")
+        )
 
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield self.store.store_device("user_id", "device1", "display_name 1")
-        yield self.store.store_device("user_id", "device2", "display_name 2")
-        yield self.store.store_device("user_id2", "device3", "display_name 3")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device1", "display_name 1")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device2", "display_name 2")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id2", "device3", "display_name 3")
+        )
 
-        res = yield self.store.get_devices_by_user("user_id")
+        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {
@@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids, ["somehost"]
+        yield defer.ensureDeferred(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
-            "somehost", -1, limit=100
+        now_stream_id, device_updates = yield defer.ensureDeferred(
+            self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )
 
         # Check original device_ids are contained within these updates
@@ -99,29 +107,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_update_device(self):
-        yield self.store.store_device("user_id", "device_id", "display_name 1")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device_id", "display_name 1")
+        )
 
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield self.store.update_device("user_id", "device_id")
-        res = yield self.store.get_device("user_id", "device_id")
+        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
-        yield self.store.update_device(
-            "user_id", "device_id", new_display_name="display_name 2"
+        yield defer.ensureDeferred(
+            self.store.update_device(
+                "user_id", "device_id", new_display_name="display_name 2"
+            )
         )
 
         # check it worked
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])
 
     @defer.inlineCallbacks
     def test_update_unknown_device(self):
         with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield self.store.update_device(
-                "user_id", "unknown_device_id", new_display_name="display_name 2"
+            yield defer.ensureDeferred(
+                self.store.update_device(
+                    "user_id", "unknown_device_id", new_display_name="display_name 2"
+                )
             )
         self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,35 +34,53 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
         self.assertEquals(
             ["#my-room:test"],
-            (yield self.store.get_aliases_for_room(self.room.to_string())),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_aliases_for_room(self.room.to_string())
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (yield self.store.get_association_from_room_alias(self.alias)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
-        room_id = yield self.store.delete_room_alias(self.alias)
+        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
         self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (yield self.store.get_association_from_room_alias(self.alias))
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            )
         )
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -30,11 +30,15 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device("user", "device", None)
+        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
-        yield self.store.set_e2e_device_keys("user", "device", now, json)
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
 
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
@@ -45,14 +49,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device("user", "device", None)
+        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
-        changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+        changed = yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         self.assertTrue(changed)
 
         # If we try to upload the same key then we should be told nothing
         # changed
-        changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+        changed = yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         self.assertFalse(changed)
 
     @defer.inlineCallbacks
@@ -60,10 +68,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.set_e2e_device_keys("user", "device", now, json)
-        yield self.store.store_device("user", "device", "display_name")
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user", "device", "display_name")
+        )
 
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
@@ -75,18 +89,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
     def test_multiple_devices(self):
         now = 1470174257070
 
-        yield self.store.store_device("user1", "device1", None)
-        yield self.store.store_device("user1", "device2", None)
-        yield self.store.store_device("user2", "device1", None)
-        yield self.store.store_device("user2", "device2", None)
+        yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
+        yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
+        yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
+        yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
 
-        yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
-        yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
-        yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
-        yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+        )
 
-        res = yield self.store.get_e2e_device_keys(
-            (("user1", "device1"), ("user2", "device2"))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys_for_cs_api(
+                (("user1", "device1"), ("user2", "device2"))
+            )
         )
         self.assertIn("user1", res)
         self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
 
         for i in range(0, 20):
-            self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+            self.get_success(
+                self.store.db_pool.runInteraction("insert", insert_event, i)
+            )
 
         # this should get the last ten
         r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         for i in range(0, 20):
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room1)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room1)
             )
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room2)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room2)
             )
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room3)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room3)
             )
 
         # Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             depth = depth_map[event_id]
 
-            self.store.db.simple_insert_txn(
+            self.store.db_pool.simple_insert_txn(
                 txn,
                 table="events",
                 values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 },
             )
 
-            self.store.db.simple_insert_many_txn(
+            self.store.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_auth",
                 values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         for event_id in auth_graph:
             next_stream_ordering += 1
             self.get_success(
-                self.store.db.runInteraction(
+                self.store.db_pool.runInteraction(
                     "insert", insert_event, event_id, next_stream_ordering
                 )
             )
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
         room_creator = self.hs.get_room_creation_handler()
 
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
 
         # Real events, forward extremities
         events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
-        yield self.store.get_unread_push_actions_for_user_in_range_for_http(
-            USER_ID, 0, 1000, 20
+        yield defer.ensureDeferred(
+            self.store.get_unread_push_actions_for_user_in_range_for_http(
+                USER_ID, 0, 1000, 20
+            )
         )
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_email(self):
-        yield self.store.get_unread_push_actions_for_user_in_range_for_email(
-            USER_ID, 0, 1000, 20
+        yield defer.ensureDeferred(
+            self.store.get_unread_push_actions_for_user_in_range_for_email(
+                USER_ID, 0, 1000, 20
+            )
         )
 
     @defer.inlineCallbacks
@@ -56,12 +60,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.db.runInteraction(
-                "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+            counts = yield defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+                )
             )
             self.assertEquals(
                 counts,
-                {"notify_count": noitf_count, "highlight_count": highlight_count},
+                {
+                    "notify_count": noitf_count,
+                    "unread_count": 0,  # Unread counts are tested in the sync tests.
+                    "highlight_count": highlight_count,
+                },
             )
 
         @defer.inlineCallbacks
@@ -72,28 +82,36 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            yield self.store.add_push_actions_to_staging(
-                event.event_id, {user_id: action}
+            yield defer.ensureDeferred(
+                self.store.add_push_actions_to_staging(
+                    event.event_id, {user_id: action}, False,
+                )
             )
-            yield self.store.db.runInteraction(
-                "",
-                self.persist_events_store._set_push_actions_for_event_and_users_txn,
-                [(event, None)],
-                [(event, None)],
+            yield defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "",
+                    self.persist_events_store._set_push_actions_for_event_and_users_txn,
+                    [(event, None)],
+                    [(event, None)],
+                )
             )
 
         def _rotate(stream):
-            return self.store.db.runInteraction(
-                "", self.store._rotate_notifs_before_txn, stream
+            return defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "", self.store._rotate_notifs_before_txn, stream
+                )
             )
 
         def _mark_read(stream, depth):
-            return self.store.db.runInteraction(
-                "",
-                self.store._remove_old_push_actions_before_txn,
-                room_id,
-                user_id,
-                stream,
+            return defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "",
+                    self.store._remove_old_push_actions_before_txn,
+                    room_id,
+                    user_id,
+                    stream,
+                )
             )
 
         yield _assert_counts(0, 0)
@@ -117,8 +135,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _inject_actions(6, PlAIN_NOTIF)
         yield _rotate(7)
 
-        yield self.store.db.simple_delete(
-            table="event_push_actions", keyvalues={"1": 1}, desc=""
+        yield defer.ensureDeferred(
+            self.store.db_pool.simple_delete(
+                table="event_push_actions", keyvalues={"1": 1}, desc=""
+            )
         )
 
         yield _assert_counts(1, 0)
@@ -136,33 +156,43 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store.db.simple_insert(
-                "events",
-                {
-                    "stream_ordering": so,
-                    "received_ts": ts,
-                    "event_id": "event%i" % so,
-                    "type": "",
-                    "room_id": "",
-                    "content": "",
-                    "processed": True,
-                    "outlier": False,
-                    "topological_ordering": 0,
-                    "depth": 0,
-                },
+            return defer.ensureDeferred(
+                self.store.db_pool.simple_insert(
+                    "events",
+                    {
+                        "stream_ordering": so,
+                        "received_ts": ts,
+                        "event_id": "event%i" % so,
+                        "type": "",
+                        "room_id": "",
+                        "content": "",
+                        "processed": True,
+                        "outlier": False,
+                        "topological_ordering": 0,
+                        "depth": 0,
+                    },
+                )
             )
 
         # start with the base case where there are no events in the table
-        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(11)
+        )
         self.assertEqual(r, 0)
 
         # now with one event
         yield add_event(2, 10)
-        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(9)
+        )
         self.assertEqual(r, 2)
-        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(10)
+        )
         self.assertEqual(r, 2)
-        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(11)
+        )
         self.assertEqual(r, 3)
 
         # add a bunch of dummy events to the events table
@@ -175,25 +205,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         ):
             yield add_event(stream_ordering, ts)
 
-        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(110)
+        )
         self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
 
         # 4 and 5 are both after 120: we want 4 rather than 5
-        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(120)
+        )
         self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
 
-        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(129)
+        )
         self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
 
         # check we can get the last event
-        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(140)
+        )
         self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
 
         # off the end
-        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(160)
+        )
         self.assertEqual(r, 21)
 
         # check we can find an event at ordering zero
         yield add_event(0, 5)
-        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(1)
+        )
         self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.unittest import HomeserverTestCase
@@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db = self.store.db  # type: Database
+        self.db_pool = self.store.db_pool  # type: DatabasePool
 
-        self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
     def _setup_db(self, txn):
         txn.execute("CREATE SEQUENCE foobar_seq")
@@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         def _create(conn):
             return MultiWriterIdGenerator(
                 conn,
-                self.db,
+                self.db_pool,
                 instance_name=instance_name,
                 table="foobar",
                 instance_column="instance_name",
@@ -55,9 +55,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 sequence_name="foobar_seq",
             )
 
-        return self.get_success(self.db.runWithConnection(_create))
+        return self.get_success(self.db_pool.runWithConnection(_create))
 
     def _insert_rows(self, instance_name: str, number: int):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
         def _insert(txn):
             for _ in range(number):
                 txn.execute(
@@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                     (instance_name,),
                 )
 
-        self.get_success(self.db.runInteraction("test_single_instance", _insert))
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def _insert_row_with_id(self, instance_name: str, stream_id: int):
+        """Insert one row as the given instance with given stream_id, updating
+        the postgres sequence position to match.
+        """
+
+        def _insert(txn):
+            txn.execute(
+                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+            )
+            txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+
+        self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
     def test_empty(self):
         """Test an ID generator against an empty database gives sensible
@@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -98,12 +115,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(id_gen.get_positions(), {"master": 7})
-                self.assertEqual(id_gen.get_current_token("master"), 7)
+                self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
     def test_multi_instance(self):
         """Test that reads and writes from multiple processes are handled
@@ -116,8 +133,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second")
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -166,7 +183,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -176,9 +193,179 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             self.assertEqual(stream_id, 8)
 
             self.assertEqual(id_gen.get_positions(), {"master": 7})
-            self.assertEqual(id_gen.get_current_token("master"), 7)
+            self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        self.get_success(self.db.runInteraction("test", _get_next_txn))
+        self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+    def test_get_persisted_upto_position(self):
+        """Test that `get_persisted_upto_position` correctly tracks updates to
+        positions.
+        """
+
+        # The following tests are a bit cheeky in that we notify about new
+        # positions via `advance` without *actually* advancing the postgres
+        # sequence.
+
+        self._insert_row_with_id("first", 3)
+        self._insert_row_with_id("second", 5)
+
+        id_gen = self._create_id_generator("first")
+
+        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+        # Min is 3 and there is a gap between 5, so we expect it to be 3.
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        # We advance "first" straight to 6. Min is now 5 but there is no gap so
+        # we expect it to be 6
+        id_gen.advance("first", 6)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # No gap, so we expect 7.
+        id_gen.advance("second", 7)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # We haven't seen 8 yet, so we expect 7 still.
+        id_gen.advance("second", 9)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # Now that we've seen 7, 8 and 9 we can got straight to 9.
+        id_gen.advance("first", 8)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+        # Jump forward with gaps. The minimum is 11, even though we haven't seen
+        # 10 we know that everything before 11 must be persisted.
+        id_gen.advance("first", 11)
+        id_gen.advance("second", 15)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+    def test_get_persisted_upto_position_get_next(self):
+        """Test that `get_persisted_upto_position` correctly tracks updates to
+        positions when `get_next` is called.
+        """
+
+        self._insert_row_with_id("first", 3)
+        self._insert_row_with_id("second", 5)
+
+        id_gen = self._create_id_generator("first")
+
+        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+        with self.get_success(id_gen.get_next()) as stream_id:
+            self.assertEqual(stream_id, 6)
+            self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # We assume that so long as `get_next` does correctly advance the
+        # `persisted_upto_position` in this case, then it will be correct in the
+        # other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+    """
+
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                instance_name=instance_name,
+                table="foobar",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="foobar_seq",
+                positive=False,
+            )
+
+        return self.get_success(self.db_pool.runWithConnection(_create))
+
+    def _insert_row(self, instance_name: str, stream_id: int):
+        """Insert one row as the given instance with given stream_id.
+        """
+
+        def _insert(txn):
+            txn.execute(
+                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+            )
+
+        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+    def test_single_instance(self):
+        """Test that reads and writes from a single process are handled
+        correctly.
+        """
+        id_gen = self._create_id_generator()
+
+        with self.get_success(id_gen.get_next()) as stream_id:
+            self._insert_row("master", stream_id)
+
+        self.assertEqual(id_gen.get_positions(), {"master": -1})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+        self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+        with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+            for stream_id in stream_ids:
+                self._insert_row("master", stream_id)
+
+        self.assertEqual(id_gen.get_positions(), {"master": -4})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+        self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+        # Test loading from DB by creating a second ID gen
+        second_id_gen = self._create_id_generator()
+
+        self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+        self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+    def test_multiple_instance(self):
+        """Tests that having multiple instances that get advanced over
+        federation works corretly.
+        """
+        id_gen_1 = self._create_id_generator("first")
+        id_gen_2 = self._create_id_generator("second")
+
+        with self.get_success(id_gen_1.get_next()) as stream_id:
+            self._insert_row("first", stream_id)
+            id_gen_2.advance("first", stream_id)
+
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+        with self.get_success(id_gen_2.get_next()) as stream_id:
+            self._insert_row("second", stream_id)
+            id_gen_1.advance("second", stream_id)
+
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..7e7f1286d9 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,12 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_users_paginate(self):
-        yield self.store.register_user(self.user.to_string(), "pass")
-        yield self.store.create_profile(self.user.localpart)
-        yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        yield defer.ensureDeferred(
+            self.store.register_user(self.user.to_string(), "pass")
+        )
+        yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        )
 
-        users, total = yield self.store.get_users_paginate(
-            0, 10, name="bc", guests=False
+        users, total = yield defer.ensureDeferred(
+            self.store.get_users_paginate(0, 10, name="bc", guests=False)
         )
 
         self.assertEquals(1, total)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..9870c74883 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.api.constants import UserTypes
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
@@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         # XXX why are we doing this here? this function is only run at startup
         # so it is odd to re-run it here.
         self.get_success(
-            self.store.db.runInteraction(
+            self.store.db_pool.runInteraction(
                 "initialise", self.store._initialise_reserved_users, threepids
             )
         )
@@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
                     self.store.user_add_threepid(user, "email", email, now, now)
                 )
 
-        d = self.store.db.runInteraction(
+        d = self.store.db_pool.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.get_success(d)
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
 
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
         self.store.user_last_seen_monthly_active = Mock(
@@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         ]
 
         self.hs.config.mau_limits_reserved_threepids = threepids
-        d = self.store.db.runInteraction(
+        d = self.store.db_pool.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.get_success(d)
@@ -293,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.register_user(user_id=user2, password_hash=None))
 
         now = int(self.hs.get_clock().time_msec())
-        self.store.user_add_threepid(user1, "email", user1_email, now, now)
-        self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        self.get_success(
+            self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        )
+        self.get_success(
+            self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        )
 
         users = self.get_success(self.store.get_registered_reserved_users())
         self.assertEqual(len(users), len(threepids))
@@ -333,7 +344,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,23 +33,36 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
-        yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        )
 
         self.assertEquals(
-            "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+            "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
-        yield self.store.set_profile_avatar_url(
-            self.u_frank.localpart, "http://my.site/here"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.u_frank.localpart, "http://my.site/here"
+            )
         )
 
         self.assertEquals(
             "http://my.site/here",
-            (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            ),
         )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
+from synapse.api.errors import NotFoundError
 from synapse.rest.client.v1 import room
 
 from tests.unittest import HomeserverTestCase
@@ -44,28 +47,19 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_storage()
 
         # Get the topological token
-        event = store.get_topological_token_for_event(last["event_id"])
-        self.pump()
-        event = self.successResultOf(event)
+        event = self.get_success(
+            store.get_topological_token_for_event(last["event_id"])
+        )
 
         # Purge everything before this topological token
-        purge = storage.purge_events.purge_history(self.room_id, event, True)
-        self.pump()
-        self.assertEqual(self.successResultOf(purge), None)
-
-        # Try and get the events
-        get_first = store.get_event(first["event_id"])
-        get_second = store.get_event(second["event_id"])
-        get_third = store.get_event(third["event_id"])
-        get_last = store.get_event(last["event_id"])
-        self.pump()
+        self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
         # and last is not.
-        self.failureResultOf(get_first)
-        self.failureResultOf(get_second)
-        self.failureResultOf(get_third)
-        self.successResultOf(get_last)
+        self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+        self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+        self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+        self.get_success(store.get_event(last["event_id"]))
 
     def test_purge_wont_delete_extrems(self):
         """
@@ -80,28 +74,21 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_datastore()
 
         # Set the topological token higher than it should be
-        event = storage.get_topological_token_for_event(last["event_id"])
-        self.pump()
-        event = self.successResultOf(event)
+        event = self.get_success(
+            storage.get_topological_token_for_event(last["event_id"])
+        )
         event = "t{}-{}".format(
             *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
         )
 
         # Purge everything before this topological token
-        purge = storage.purge_history(self.room_id, event, True)
+        purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
         self.pump()
         f = self.failureResultOf(purge)
         self.assertIn("greater than forward", f.value.args[0])
 
         # Try and get the events
-        get_first = storage.get_event(first["event_id"])
-        get_second = storage.get_event(second["event_id"])
-        get_third = storage.get_event(third["event_id"])
-        get_last = storage.get_event(last["event_id"])
-        self.pump()
-
-        # Nothing is deleted.
-        self.successResultOf(get_first)
-        self.successResultOf(get_second)
-        self.successResultOf(get_third)
-        self.successResultOf(get_last)
+        self.get_success(storage.get_event(first["event_id"]))
+        self.get_success(storage.get_event(second["event_id"]))
+        self.get_success(storage.get_event(third["event_id"]))
+        self.get_success(storage.get_event(last["event_id"]))
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..1ea35d60c1 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
             @defer.inlineCallbacks
             def build(self, prev_event_ids):
-                built_event = yield self._base_builder.build(prev_event_ids)
+                built_event = yield defer.ensureDeferred(
+                    self._base_builder.build(prev_event_ids)
+                )
 
                 built_event._event_id = self._event_id
                 built_event._dict["event_id"] = self._event_id
@@ -249,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             def room_id(self):
                 return self._base_builder.room_id
 
+            @property
+            def type(self):
+                return self._base_builder.type
+
         event_1, context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 EventIdManglingBuilder(
@@ -341,7 +347,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         event_json = self.get_success(
-            self.store.db.simple_select_one_onecol(
+            self.store.db_pool.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -359,7 +365,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(60 * 60 * 2)
 
         event_json = self.get_success(
-            self.store.db.simple_select_one_onecol(
+            self.store.db_pool.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 71a40a0a49..6b582771fe 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -36,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_register(self):
-        yield self.store.register_user(self.user_id, self.pwhash)
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEquals(
             {
@@ -52,17 +53,21 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "user_type": None,
                 "deactivated": 0,
             },
-            (yield self.store.get_user_by_id(self.user_id)),
+            (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
         )
 
     @defer.inlineCallbacks
     def test_add_tokens(self):
-        yield self.store.register_user(self.user_id, self.pwhash)
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+            )
         )
 
-        result = yield self.store.get_user_by_access_token(self.tokens[1])
+        result = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[1])
+        )
 
         self.assertDictContainsSubset(
             {"name": self.user_id, "device_id": self.device_id}, result
@@ -73,31 +78,41 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
         # add some tokens
-        yield self.store.register_user(self.user_id, self.pwhash)
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+            )
         )
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+            )
         )
 
         # now delete some
-        yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id
+        yield defer.ensureDeferred(
+            self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
         )
 
         # check they were deleted
-        user = yield self.store.get_user_by_access_token(self.tokens[1])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[1])
+        )
         self.assertIsNone(user, "access token was not deleted by device_id")
 
         # check the one not associated with the device was not deleted
-        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[0])
+        )
         self.assertEqual(self.user_id, user["name"])
 
         # now delete the rest
-        yield self.store.user_delete_access_tokens(self.user_id)
+        yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
 
-        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[0])
+        )
         self.assertIsNone(user, "access token was not deleted without device_id")
 
     @defer.inlineCallbacks
@@ -105,14 +120,48 @@ class RegistrationStoreTestCase(unittest.TestCase):
         TEST_USER = "@test:test"
         SUPPORT_USER = "@support:test"
 
-        res = yield self.store.is_support_user(None)
+        res = yield defer.ensureDeferred(self.store.is_support_user(None))
         self.assertFalse(res)
-        yield self.store.register_user(user_id=TEST_USER, password_hash=None)
-        res = yield self.store.is_support_user(TEST_USER)
+        yield defer.ensureDeferred(
+            self.store.register_user(user_id=TEST_USER, password_hash=None)
+        )
+        res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
         self.assertFalse(res)
 
-        yield self.store.register_user(
-            user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+        yield defer.ensureDeferred(
+            self.store.register_user(
+                user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+            )
         )
-        res = yield self.store.is_support_user(SUPPORT_USER)
+        res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
         self.assertTrue(res)
+
+    @defer.inlineCallbacks
+    def test_3pid_inhibit_invalid_validation_session_error(self):
+        """Tests that enabling the configuration option to inhibit 3PID errors on
+        /requestToken also inhibits validation errors caused by an unknown session ID.
+        """
+
+        # Check that, with the config setting set to false (the default value), a
+        # validation error is caused by the unknown session ID.
+        try:
+            yield defer.ensureDeferred(
+                self.store.validate_threepid_session(
+                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                )
+            )
+        except ThreepidValidationError as e:
+            self.assertEquals(e.msg, "Unknown session_id", e)
+
+        # Set the config setting to true.
+        self.store._ignore_unknown_session_error = True
+
+        # Check that now the validation error is caused by the token not matching.
+        try:
+            yield defer.ensureDeferred(
+                self.store.validate_threepid_session(
+                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                )
+            )
+        except ThreepidValidationError as e:
+            self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3b78d48896..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
         self.alias = RoomAlias.from_string("#a-room-name:test")
         self.u_creator = UserID.from_string("@creator:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id=self.u_creator.to_string(),
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id=self.u_creator.to_string(),
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
@@ -52,7 +54,13 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
             },
-            (yield self.store.get_room(self.room.to_string())),
+            (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+        )
+
+    @defer.inlineCallbacks
+    def test_get_room_unknown_room(self):
+        self.assertIsNone(
+            (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
         )
 
     @defer.inlineCallbacks
@@ -63,7 +71,21 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "public": True,
             },
-            (yield self.store.get_room_with_stats(self.room.to_string())),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_room_with_stats(self.room.to_string())
+                )
+            ),
+        )
+
+    @defer.inlineCallbacks
+    def test_get_room_with_stats_unknown_room(self):
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_room_with_stats("!uknown:test")
+                )
+            ),
         )
 
 
@@ -80,17 +102,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
         self.room = RoomID.from_string("!abcde:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.storage.persistence.persist_event(
-            self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(
+                self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+            )
         )
 
     @defer.inlineCallbacks
@@ -101,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
@@ -117,7 +145,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005e6..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
         self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
 
-        self.pump(20)
+        self.pump()
 
         self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
 
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Initialises to 1 -- itself
         self.assertEqual(self.store._known_servers_count, 1)
 
-        self.pump(20)
+        self.pump()
 
         # No rooms have been joined, so technically the SQL returns 0, but it
         # will still say it knows about itself.
@@ -111,25 +111,29 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
         self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
 
-        self.pump(20)
+        self.pump(1)
 
         # It now knows about Charlie's server.
         self.assertEqual(self.store._known_servers_count, 2)
 
     def test_get_joined_users_from_context(self):
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
-        bob_event = event_injection.inject_member_event(
-            self.hs, room, self.u_bob, Membership.JOIN
+        bob_event = self.get_success(
+            event_injection.inject_member_event(
+                self.hs, room, self.u_bob, Membership.JOIN
+            )
         )
 
         # first, create a regular event
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.1",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.1",
+                content={},
+            )
         )
 
         users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Regression test for #7376: create a state event whose key matches bob's
         # user_id, but which is *not* a membership event, and persist that; then check
         # that `get_joined_users_from_context` returns the correct users for the next event.
-        non_member_event = event_injection.inject_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_bob,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.2",
-            state_key=self.u_bob,
-            content={},
+        non_member_event = self.get_success(
+            event_injection.inject_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_bob,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.2",
+                state_key=self.u_bob,
+                content={},
+            )
         )
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[non_member_event.event_id],
-            type="m.test.3",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[non_member_event.event_id],
+                type="m.test.3",
+                content={},
+            )
         )
         users = self.get_success(
             self.store.get_joined_users_from_context(event, context)
@@ -171,20 +179,20 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def test_can_rerun_update(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
         self.get_success(self.room_creator.create_room(requester, {}))
 
         # Register the background update to run again.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "current_state_events_membership",
@@ -195,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308ff4..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
@@ -64,11 +66,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
 
         return event
 
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups_ids(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.storage.state.get_state_for_event(e5.event_id)
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(e5.event_id)
+        )
 
         self.assertIsNotNone(e4)
 
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+            )
         )
 
         self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: {self.u_alice.to_string()}},
-                include_others=True,
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    include_others=True,
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: set()}, include_others=True
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.storage.state.get_state_groups_ids(
-            room_id, [e5.event_id]
+        group_ids = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..738e912468 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -31,16 +31,24 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
-        yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
-        yield self.store.update_profile_in_user_dir(BOB, "bob", None)
-        yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
-        yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(ALICE, "alice", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BOB, "bob", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+        )
 
     @defer.inlineCallbacks
     def test_search_user_dir(self):
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
-        r = yield self.store.search_user_dir(ALICE, "bob", 10)
+        r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
         self.assertFalse(r["limited"])
         self.assertEqual(1, len(r["results"]))
         self.assertDictEqual(
@@ -51,7 +59,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     def test_search_user_dir_all_users(self):
         self.hs.config.user_directory_search_all_users = True
         try:
-            r = yield self.store.search_user_dir(ALICE, "bob", 10)
+            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
             self.assertFalse(r["limited"])
             self.assertEqual(2, len(r["results"]))
             self.assertDictEqual(