summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_e2e_keys.py8
-rw-r--r--tests/handlers/test_stats.py80
-rw-r--r--tests/handlers/test_sync.py33
-rw-r--r--tests/handlers/test_typing.py24
-rw-r--r--tests/handlers/test_user_directory.py30
-rw-r--r--tests/replication/slave/storage/_base.py5
-rw-r--r--tests/rest/admin/test_admin.py2
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/storage/test__base.py16
-rw-r--r--tests/storage/test_appservice.py17
-rw-r--r--tests/storage/test_background_update.py28
-rw-r--r--tests/storage/test_base.py21
-rw-r--r--tests/storage/test_cleanup_extrems.py18
-rw-r--r--tests/storage/test_client_ips.py87
-rw-r--r--tests/storage/test_event_federation.py8
-rw-r--r--tests/storage/test_event_push_actions.py12
-rw-r--r--tests/storage/test_monthly_active_users.py6
-rw-r--r--tests/storage/test_profile.py3
-rw-r--r--tests/storage/test_redaction.py4
-rw-r--r--tests/storage/test_roommember.py20
-rw-r--r--tests/storage/test_user_directory.py4
-rw-r--r--tests/test_federation.py16
-rw-r--r--tests/unittest.py19
-rw-r--r--tests/util/test_logcontext.py24
-rw-r--r--tests/util/test_snapshot_cache.py63
25 files changed, 306 insertions, 246 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 854eb6c024..fdfa2cbbc4 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
+    test_replace_master_key.skip = (
+        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+    )
+
     @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
@@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             ],
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
+
+    test_upload_signatures.skip = (
+        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
+    )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 380fd0d107..d9d312f0fb 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_rooms",
@@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
@@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
     def get_all_room_state(self):
-        return self.store.simple_select_list(
+        return self.store.db.simple_select_list(
             "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
         )
 
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
 
         return self.get_success(
-            self.store.simple_select_one(
+            self.store.db.simple_select_one(
                 table + "_historical",
                 {id_col: stat_id, end_ts: end_ts},
                 cols,
@@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # Do the initial population of the stats via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
     def test_initial_room(self):
         """
@@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         r = self.get_success(self.get_all_room_state())
 
@@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # the position that the deltas should begin at, once they take over.
         self.hs.config.stats_enabled = True
         self.handler.stats_enabled = True
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
         self.get_success(
-            self.store.simple_update_one(
+            self.store.db.simple_update_one(
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": 0},
@@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # Now, before the table is actually ingested, add some more events.
         self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token)
@@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # Now do the initial ingestion.
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-        self.store._all_done = False
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        self.store.db.updates._all_done = False
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         self.reactor.advance(86401)
 
@@ -653,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # preparation stage of the initial background update
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         self.get_success(
-            self.store.simple_delete(
+            self.store.db.simple_delete(
                 "room_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
         self.get_success(
-            self.store.simple_delete(
+            self.store.db.simple_delete(
                 "user_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
@@ -673,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # now do the background updates
 
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_rooms",
@@ -685,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
@@ -695,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         r1stats_complete = self._get_current_stats("room", r1)
         u1stats_complete = self._get_current_stats("user", u1)
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..758ee071a5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
 
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
 from synapse.types import UserID
 
 import tests.unittest
 import tests.utils
-from tests.utils import setup_test_homeserver
 
 
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
     """ Tests Sync Handler. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.sync_handler = SyncHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.hs = hs
+        self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_wait_for_sync_for_user_auth_blocking(self):
 
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
         sync_config = self._generate_sync_config(user_id1)
 
+        self.reactor.advance(100)  # So we get not 0 time
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 1
 
         # Check that the happy case does not throw errors
-        yield self.store.upsert_monthly_active_user(user_id1)
-        yield self.sync_handler.wait_for_sync_for_user(sync_config)
+        self.get_success(self.store.upsert_monthly_active_user(user_id1))
+        self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
 
         # Test that global lock works
         self.hs.config.hs_disabled = True
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
         self.hs.config.hs_disabled = False
 
         sync_config = self._generate_sync_config(user_id2)
 
-        with self.assertRaises(ResourceLimitError) as e:
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        e = self.get_failure(
+            self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+        )
+        self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     def _generate_sync_config(self, user_id):
         return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f6d8660285..92b8726093 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -163,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -227,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -279,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -300,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
@@ -317,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
 
         self.assertEquals(self.event_source.get_current_key(), 2)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+        )
         self.assertEquals(
             events[0],
             [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -335,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.on_new_event.reset_mock()
 
         self.assertEquals(self.event_source.get_current_key(), 3)
-        events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        events = self.get_success(
+            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index d5b1c5b4ac..26071059d2 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_in_public_rooms(self):
         r = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 "users_in_public_rooms", None, ("user_id", "room_id")
             )
         )
@@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_who_share_private_rooms(self):
         return self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 "users_who_share_private_rooms",
                 None,
                 ["user_id", "other_user_id", "room_id"],
@@ -181,10 +181,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_createtables",
@@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_rooms",
@@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_users",
@@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_cleanup",
@@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         shares_private = self.get_users_who_share_private_rooms()
         public_users = self.get_users_in_public_rooms()
@@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._add_background_updates()
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         shares_private = self.get_users_who_share_private_rooms()
         public_users = self.get_users_in_public_rooms()
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index e7472e3a93..3dae83c543 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -42,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
-        self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
+        self.slaved_store = self.STORE_TYPE(
+            Database(hs), self.hs.get_db_conn(), self.hs
+        )
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 124ce0768a..0ed2594381 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
             "state_groups_state",
         ):
             count = self.get_success(
-                self.store.simple_select_one_onecol(
+                self.store.db.simple_select_one_onecol(
                     table=table,
                     keyvalues={"room_id": room_id},
                     retcol="COUNT(*)",
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..4bc3aaf02d 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code)
 
         self.assertEquals(self.event_source.get_current_key(), 1)
-        events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+        events = self.get_success(
+            self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+        )
         self.assertEquals(
             events[0],
             [
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 7b7434a468..d491ea2924 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         self.table_name = "table_" + hs.get_secrets().token_hex(6)
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "create",
                 lambda x, *a: x.execute(*a),
                 "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "index",
                 lambda x, *a: x.execute(*a),
                 "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["hello"], ["there"]]
 
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "test",
-                self.storage.simple_upsert_many_txn,
+                self.storage.db.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.simple_select_list(
+            self.storage.db.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["bleb"]]
 
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "test",
-                self.storage.simple_upsert_many_txn,
+                self.storage.db.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.simple_select_list(
+            self.storage.db.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..2e521e9ab7 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs.get_db_conn(), hs)
+        database = Database(hs)
+        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(hs.get_db_conn(), hs)
+        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 9fabe3fbc0..aec76f4ab1 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
 
         self.update_handler = Mock()
 
-        yield self.store.register_background_update_handler(
+        yield self.store.db.updates.register_background_update_handler(
             "test_update", self.update_handler
         )
 
@@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
         # (perhaps we should run them as part of the test HS setup, since we
         # run all of the other schema setup stuff there?)
         while True:
-            res = yield self.store.do_next_background_update(1000)
+            res = yield self.store.db.updates.do_next_background_update(1000)
             if res is None:
                 break
 
@@ -37,9 +37,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
         def update(progress, count):
             self.clock.advance_time_msec(count * duration_ms)
             progress = {"my_key": progress["my_key"] + 1}
-            yield self.store.runInteraction(
+            yield self.store.db.runInteraction(
                 "update_progress",
-                self.store._background_update_progress_txn,
+                self.store.db.updates._background_update_progress_txn,
                 "test_update",
                 progress,
             )
@@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase):
 
         self.update_handler.side_effect = update
 
-        yield self.store.start_background_update("test_update", {"my_key": 1})
+        yield self.store.db.updates.start_background_update(
+            "test_update", {"my_key": 1}
+        )
 
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
+        result = yield self.store.db.updates.do_next_background_update(
+            duration_ms * desired_count
+        )
         self.assertIsNotNone(result)
         self.update_handler.assert_called_once_with(
-            {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
+            {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE
         )
 
         # second step: complete the update
         @defer.inlineCallbacks
         def update(progress, count):
-            yield self.store._end_background_update("test_update")
+            yield self.store.db.updates._end_background_update("test_update")
             return count
 
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
+        result = yield self.store.db.updates.do_next_background_update(
+            duration_ms * desired_count
+        )
         self.assertIsNotNone(result)
         self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
 
         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
+        result = yield self.store.db.updates.do_next_background_update(
+            duration_ms * desired_count
+        )
         self.assertIsNone(result)
         self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index de5e4a5fce..537cfe9f64 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 
 from tests import unittest
@@ -59,13 +60,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
             "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
         )
 
-        self.datastore = SQLBaseStore(None, hs)
+        self.datastore = SQLBaseStore(Database(hs), None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_insert(
+        yield self.datastore.db.simple_insert(
             table="tablename", values={"columname": "Value"}
         )
 
@@ -77,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_insert(
+        yield self.datastore.db.simple_insert(
             table="tablename",
             # Use OrderedDict() so we can assert on the SQL generated
             values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -92,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore.simple_select_one_onecol(
+        value = yield self.datastore.db.simple_select_one_onecol(
             table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
         )
 
@@ -106,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore.simple_select_one(
+        ret = yield self.datastore.db.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             retcols=["colA", "colB", "colC"],
@@ -122,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore.simple_select_one(
+        ret = yield self.datastore.db.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "Not here"},
             retcols=["colA"],
@@ -137,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore.simple_select_list(
+        ret = yield self.datastore.db.simple_select_list(
             table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
         )
 
@@ -150,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_update_one(
+        yield self.datastore.db.simple_update_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             updatevalues={"columnname": "New Value"},
@@ -165,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_update_one(
+        yield self.datastore.db.simple_update_one(
             table="tablename",
             keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
             updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -180,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_delete_one(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.simple_delete_one(
+        yield self.datastore.db.simple_delete_one(
             table="tablename", keyvalues={"keycol": "Go away"}
         )
 
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 69dcaa63d5..029ac26454 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         """Re run the background update to clean up the extremities.
         """
         # Make sure we don't clash with in progress updates.
-        self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+        self.assertTrue(
+            self.store.db.updates._all_done, "Background updates are still ongoing"
+        )
 
         schema_path = os.path.join(
             prepare_database.dir_path,
@@ -62,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             prepare_database.executescript(txn, schema_path)
 
         self.get_success(
-            self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+            self.store.db.runInteraction(
+                "test_delete_forward_extremities", run_delta_file
+            )
         )
 
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
     def test_soft_failed_extremities_handled_correctly(self):
         """Test that extremities are correctly calculated in the presence of
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 25bdd2c163..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -81,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -112,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -202,25 +206,31 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
     def test_devices_last_seen_bg_update(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
-
         # Force persisting to disk
         self.reactor.advance(200)
 
         # But clear the associated entry in devices table
         self.get_success(
-            self.store.simple_update(
+            self.store.db.simple_update(
                 table="devices",
-                keyvalues={"user_id": user_id, "device_id": "device_id"},
+                keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                 desc="test_devices_last_seen_bg_update",
             )
@@ -228,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get nulls when querying
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
@@ -245,7 +255,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "devices_last_seen",
@@ -256,22 +266,26 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         # Now let's actually drive the updates to completion
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # We should now get the correct result again
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
@@ -281,14 +295,21 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
     def test_old_user_ips_pruned(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -297,7 +318,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should see that in the DB
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -312,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                     "access_token": "access_token",
                     "ip": "ip",
                     "user_agent": "user_agent",
-                    "device_id": "device_id",
+                    "device_id": device_id,
                     "last_seen": 0,
                 }
             ],
@@ -323,7 +344,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should get no results.
         result = self.get_success(
-            self.store.simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -335,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But we should still get the correct values for the device
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2fe50377f8..eadfb90a22 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             )
 
         for i in range(0, 11):
-            yield self.store.runInteraction("insert", insert_event, i)
+            yield self.store.db.runInteraction("insert", insert_event, i)
 
         # this should get the last five and five others
         r = yield self.store.get_prev_events_for_room(room_id)
@@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             )
 
         for i in range(0, 20):
-            yield self.store.runInteraction("insert", insert_event, i, room1)
-            yield self.store.runInteraction("insert", insert_event, i, room2)
-            yield self.store.runInteraction("insert", insert_event, i, room3)
+            yield self.store.db.runInteraction("insert", insert_event, i, room1)
+            yield self.store.db.runInteraction("insert", insert_event, i, room2)
+            yield self.store.db.runInteraction("insert", insert_event, i, room3)
 
         # Test simple case
         r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 2337a1ae46..d4bcf1821e 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.runInteraction(
+            counts = yield self.store.db.runInteraction(
                 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
             )
             self.assertEquals(
@@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             yield self.store.add_push_actions_to_staging(
                 event.event_id, {user_id: action}
             )
-            yield self.store.runInteraction(
+            yield self.store.db.runInteraction(
                 "",
                 self.store._set_push_actions_for_event_and_users_txn,
                 [(event, None)],
@@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         def _rotate(stream):
-            return self.store.runInteraction(
+            return self.store.db.runInteraction(
                 "", self.store._rotate_notifs_before_txn, stream
             )
 
         def _mark_read(stream, depth):
-            return self.store.runInteraction(
+            return self.store.db.runInteraction(
                 "",
                 self.store._remove_old_push_actions_before_txn,
                 room_id,
@@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _inject_actions(6, PlAIN_NOTIF)
         yield _rotate(7)
 
-        yield self.store.simple_delete(
+        yield self.store.db.simple_delete(
             table="event_push_actions", keyvalues={"1": 1}, desc=""
         )
 
@@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store.simple_insert(
+            return self.store.db.simple_insert(
                 "events",
                 {
                     "stream_ordering": so,
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 90a63dc477..3c78faab45 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.user_add_threepid(user1, "email", user1_email, now, now)
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
 
-        self.store.runInteraction(
+        self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.pump()
@@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
                 )
 
         self.hs.config.mau_limits_reserved_threepids = threepids
-        self.store.runInteraction(
+        self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         count = self.store.get_monthly_active_count()
@@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
             {"medium": "email", "address": user2_email},
         ]
         self.hs.config.mau_limits_reserved_threepids = threepids
-        self.store.runInteraction(
+        self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 24c7fe16c3..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.profile import ProfileStore
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = ProfileStore(hs.get_db_conn(), hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 4930b6777e..dc45173355 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         event_json = self.get_success(
-            self.store.simple_select_one_onecol(
+            self.store.db.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(60 * 60 * 2)
 
         event_json = self.get_success(
-            self.store.simple_select_one_onecol(
+            self.store.db.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index d389cf578f..7840f63fe3 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -122,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
     def test_can_rerun_update(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
@@ -132,7 +136,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.simple_insert(
+            self.store.db.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "current_state_events_membership",
@@ -143,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         # Now let's actually drive the updates to completion
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 7eea57c0e2..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+        self.store = self.hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 7d82b58466..ad165d7295 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase):
         self.reactor.advance(0.1)
         self.room_id = self.successResultOf(room)["room_id"]
 
+        self.store = self.homeserver.get_datastore()
+
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
             maybeDeferred(
@@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase):
         # Make sure we actually joined the room
         self.assertEqual(
             self.successResultOf(
-                maybeDeferred(
-                    self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                    self.room_id,
-                )
+                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
             )[0],
             "$join:test.serv",
         )
@@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                self.room_id,
-            )
+            maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         )[0]
 
         # Now lie about an event
@@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(
-            self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
-        )
+        extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/unittest.py b/tests/unittest.py
index 295573bc46..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@
 import gc
 import hashlib
 import hmac
+import inspect
 import logging
 import time
 
@@ -25,7 +26,7 @@ from mock import Mock
 
 from canonicaljson import json
 
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
@@ -401,10 +402,12 @@ class HomeserverTestCase(TestCase):
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()
 
-        # Run the database background updates.
-        if hasattr(stor, "do_next_background_update"):
-            while not self.get_success(stor.has_completed_background_updates()):
-                self.get_success(stor.do_next_background_update(1))
+        # Run the database background updates, when running against "master".
+        if hs.__class__.__name__ == "TestHomeServer":
+            while not self.get_success(
+                stor.db.updates.has_completed_background_updates()
+            ):
+                self.get_success(stor.db.updates.do_next_background_update(1))
 
         return hs
 
@@ -415,6 +418,8 @@ class HomeserverTestCase(TestCase):
         self.reactor.pump([by] * 100)
 
     def get_success(self, d, by=0.0):
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump(by=by)
@@ -424,6 +429,8 @@ class HomeserverTestCase(TestCase):
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump()
@@ -544,7 +551,7 @@ class HomeserverTestCase(TestCase):
         Add the given event as an extremity to the room.
         """
         self.get_success(
-            self.hs.get_datastore().simple_insert(
+            self.hs.get_datastore().db.simple_insert(
                 table="event_forward_extremities",
                 values={"room_id": room_id, "event_id": event_id},
                 desc="test_add_extremity",
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")
 
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_await(self):
+        # an async function which retuns an incomplete coroutine, but doesn't
+        # follow the synapse rules.
+
+        async def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            await d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
-    def setUp(self):
-        self.cache = SnapshotCache()
-        self.cache.DURATION_MS = 1
-
-    def test_get_set(self):
-        # Check that getting a missing key returns None
-        self.assertEquals(self.cache.get(0, "key"), None)
-
-        # Check that setting a key with a deferred returns
-        # a deferred that resolves when the initial deferred does
-        d = Deferred()
-        set_result = self.cache.set(0, "key", d)
-        self.assertIsNotNone(set_result)
-        self.assertFalse(set_result.called)
-
-        # Check that getting the key before the deferred has resolved
-        # returns a deferred that resolves when the initial deferred does.
-        get_result_at_10 = self.cache.get(10, "key")
-        self.assertIsNotNone(get_result_at_10)
-        self.assertFalse(get_result_at_10.called)
-
-        # Check that the returned deferreds resolve when the initial deferred
-        # does.
-        d.callback("v")
-        self.assertTrue(set_result.called)
-        self.assertTrue(get_result_at_10.called)
-
-        # Check that getting the key after the deferred has resolved
-        # before the cache expires returns a resolved deferred.
-        get_result_at_11 = self.cache.get(11, "key")
-        self.assertIsNotNone(get_result_at_11)
-        if isinstance(get_result_at_11, Deferred):
-            # The cache may return the actual result rather than a deferred
-            self.assertTrue(get_result_at_11.called)
-
-        # Check that getting the key after the deferred has resolved
-        # after the cache expires returns None
-        get_result_at_12 = self.cache.get(12, "key")
-        self.assertIsNone(get_result_at_12)