summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test__base.py30
-rw-r--r--tests/storage/test_appservice.py86
-rw-r--r--tests/storage/test_background_update.py84
-rw-r--r--tests/storage/test_base.py36
-rw-r--r--tests/storage/test_cleanup_extrems.py181
-rw-r--r--tests/storage/test_client_ips.py197
-rw-r--r--tests/storage/test_database.py52
-rw-r--r--tests/storage/test_devices.py53
-rw-r--r--tests/storage/test_e2e_room_keys.py75
-rw-r--r--tests/storage/test_end_to_end_keys.py12
-rw-r--r--tests/storage/test_event_federation.py198
-rw-r--r--tests/storage/test_event_metrics.py38
-rw-r--r--tests/storage/test_event_push_actions.py15
-rw-r--r--tests/storage/test_id_generators.py184
-rw-r--r--tests/storage/test_keys.py15
-rw-r--r--tests/storage/test_main.py46
-rw-r--r--tests/storage/test_monthly_active_users.py377
-rw-r--r--tests/storage/test_profile.py3
-rw-r--r--tests/storage/test_purge.py15
-rw-r--r--tests/storage/test_redaction.py93
-rw-r--r--tests/storage/test_registration.py3
-rw-r--r--tests/storage/test_room.py21
-rw-r--r--tests/storage/test_roommember.py97
-rw-r--r--tests/storage/test_state.py160
-rw-r--r--tests/storage/test_transactions.py11
-rw-r--r--tests/storage/test_user_directory.py4
26 files changed, 1610 insertions, 476 deletions
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index dd49a14524..5a50e4fdd4 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached
 from tests import unittest
 
 
-class CacheTestCase(unittest.TestCase):
-    def setUp(self):
+class CacheTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, homeserver):
         self.cache = Cache("test")
 
     def test_empty(self):
@@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase):
         cache.get(3)
 
 
-class CacheDecoratorTestCase(unittest.TestCase):
+class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     @defer.inlineCallbacks
     def test_passthrough(self):
         class A(object):
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo"), d.result)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
@@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
         callcount2 = [0]
 
         class A(object):
-            @cached(max_entries=4)  # HACK: This makes it 2 due to cache factor
+            @cached(max_entries=2)
             def func(self, key):
                 callcount[0] += 1
                 return key
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         self.table_name = "table_" + hs.get_secrets().token_hex(6)
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "create",
                 lambda x, *a: x.execute(*a),
                 "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "index",
                 lambda x, *a: x.execute(*a),
                 "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["hello"], ["there"]]
 
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "test",
-                self.storage._simple_upsert_many_txn,
+                self.storage.db.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -367,13 +367,13 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage._simple_select_list(
+            self.storage.db.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
         self.assertEqual(
             set(self._dump_to_tuple(res)),
-            set([(1, "user1", "hello"), (2, "user2", "there")]),
+            {(1, "user1", "hello"), (2, "user2", "there")},
         )
 
         # Update only user2
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["bleb"]]
 
         self.get_success(
-            self.storage.runInteraction(
+            self.storage.db.runInteraction(
                 "test",
-                self.storage._simple_upsert_many_txn,
+                self.storage.db.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -394,11 +394,11 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage._simple_select_list(
+            self.storage.db.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
         self.assertEqual(
             set(self._dump_to_tuple(res)),
-            set([(1, "user1", "hello"), (2, "user2", "bleb")]),
+            {(1, "user1", "hello"), (2, "user2", "bleb")},
         )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 622b16a071..ef296e7dab 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,10 +24,11 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -42,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = self.as_yaml_files
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
         self.as_token = "token1"
@@ -54,7 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -65,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
                 pass
 
     def _add_appservice(self, as_token, id, url, hs_token, sender):
-        as_yaml = dict(
-            url=url,
-            as_token=as_token,
-            hs_token=hs_token,
-            id=id,
-            sender_localpart=sender,
-            namespaces={},
-        )
+        as_yaml = {
+            "url": url,
+            "as_token": as_token,
+            "hs_token": hs_token,
+            "id": id,
+            "sender_localpart": sender,
+            "namespaces": {},
+        }
         # use the token as the filename
         with open(as_token, "w") as outfile:
             outfile.write(yaml.dump(as_yaml))
@@ -106,12 +110,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = self.as_yaml_files
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -123,17 +124,25 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
 
-    def _add_service(self, url, as_token, id):
-        as_yaml = dict(
-            url=url,
-            as_token=as_token,
-            hs_token="something",
-            id=id,
-            sender_localpart="a_sender",
-            namespaces={},
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
         )
+
+    def _add_service(self, url, as_token, id):
+        as_yaml = {
+            "url": url,
+            "as_token": as_token,
+            "hs_token": "something",
+            "id": id,
+            "sender_localpart": "a_sender",
+            "namespaces": {},
+        }
         # use the token as the filename
         with open(as_token, "w") as outfile:
             outfile.write(yaml.dump(as_yaml))
@@ -375,15 +384,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         )
         self.assertEquals(2, len(services))
         self.assertEquals(
-            set([self.as_list[2]["id"], self.as_list[0]["id"]]),
-            set([services[0].id, services[1].id]),
+            {self.as_list[2]["id"], self.as_list[0]["id"]},
+            {services[0].id, services[1].id},
         )
 
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
@@ -413,10 +422,13 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = [f1, f2]
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -428,11 +440,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = [f1, f2]
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -449,11 +464,14 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         )
 
         hs.config.app_service_config_files = [f1, f2]
-        hs.config.event_cache_size = 1
+        hs.config.caches.event_cache_size = 1
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 9fabe3fbc0..940b166129 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -2,74 +2,90 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.storage.background_updates import BackgroundUpdater
+
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 
-class BackgroundUpdateTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
+class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, homeserver):
+        self.updates = self.hs.get_datastore().db.updates  # type: BackgroundUpdater
+        # the base test class should have run the real bg updates for us
+        self.assertTrue(
+            self.get_success(self.updates.has_completed_background_updates())
+        )
 
         self.update_handler = Mock()
-
-        yield self.store.register_background_update_handler(
+        self.updates.register_background_update_handler(
             "test_update", self.update_handler
         )
 
-        # run the real background updates, to get them out the way
-        # (perhaps we should run them as part of the test HS setup, since we
-        # run all of the other schema setup stuff there?)
-        while True:
-            res = yield self.store.do_next_background_update(1000)
-            if res is None:
-                break
-
-    @defer.inlineCallbacks
     def test_do_background_update(self):
-        desired_count = 1000
+        # the time we claim each update takes
         duration_ms = 42
 
+        # the target runtime for each bg update
+        target_background_update_duration_ms = 50000
+
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db.simple_insert(
+                "background_updates",
+                values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
+            )
+        )
+
         # first step: make a bit of progress
         @defer.inlineCallbacks
         def update(progress, count):
-            self.clock.advance_time_msec(count * duration_ms)
+            yield self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield self.store.runInteraction(
+            yield store.db.runInteraction(
                 "update_progress",
-                self.store._background_update_progress_txn,
+                self.updates._background_update_progress_txn,
                 "test_update",
                 progress,
             )
             return count
 
         self.update_handler.side_effect = update
-
-        yield self.store.start_background_update("test_update", {"my_key": 1})
-
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
-        self.assertIsNotNone(result)
+        res = self.get_success(
+            self.updates.do_next_background_update(
+                target_background_update_duration_ms
+            ),
+            by=0.1,
+        )
+        self.assertFalse(res)
+
+        # on the first call, we should get run with the default background update size
         self.update_handler.assert_called_once_with(
-            {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
+            {"my_key": 1}, self.updates.DEFAULT_BACKGROUND_BATCH_SIZE
         )
 
         # second step: complete the update
+        # we should now get run with a much bigger number of items to update
         @defer.inlineCallbacks
         def update(progress, count):
-            yield self.store._end_background_update("test_update")
+            self.assertEqual(progress, {"my_key": 2})
+            self.assertAlmostEqual(
+                count, target_background_update_duration_ms / duration_ms, places=0,
+            )
+            yield self.updates._end_background_update("test_update")
             return count
 
         self.update_handler.side_effect = update
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
-        self.assertIsNotNone(result)
-        self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
+        result = self.get_success(
+            self.updates.do_next_background_update(target_background_update_duration_ms)
+        )
+        self.assertFalse(result)
+        self.update_handler.assert_called_once()
 
         # third step: we don't expect to be called any more
         self.update_handler.reset_mock()
-        result = yield self.store.do_next_background_update(duration_ms * desired_count)
-        self.assertIsNone(result)
+        result = self.get_success(
+            self.updates.do_next_background_update(target_background_update_duration_ms)
+        )
+        self.assertTrue(result)
         self.assertFalse(self.update_handler.called)
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index c778de1f0c..278961c331 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,6 +21,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 
 from tests import unittest
@@ -50,22 +51,25 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 
         config = Mock()
         config._disable_native_upserts = True
-        config.event_cache_size = 1
-        config.database_config = {"name": "sqlite3"}
-        engine = create_engine(config.database_config)
+        config.caches = Mock()
+        config.caches.event_cache_size = 1
+        hs = TestHomeServer("test", config=config)
+
+        sqlite_config = {"name": "sqlite3"}
+        engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
-        hs = TestHomeServer(
-            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
-        )
 
-        self.datastore = SQLBaseStore(None, hs)
+        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db._db_pool = self.db_pool
+
+        self.datastore = SQLBaseStore(db, None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore._simple_insert(
+        yield self.datastore.db.simple_insert(
             table="tablename", values={"columname": "Value"}
         )
 
@@ -77,7 +81,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore._simple_insert(
+        yield self.datastore.db.simple_insert(
             table="tablename",
             # Use OrderedDict() so we can assert on the SQL generated
             values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -92,7 +96,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore._simple_select_one_onecol(
+        value = yield self.datastore.db.simple_select_one_onecol(
             table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
         )
 
@@ -106,7 +110,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore._simple_select_one(
+        ret = yield self.datastore.db.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             retcols=["colA", "colB", "colC"],
@@ -122,7 +126,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore._simple_select_one(
+        ret = yield self.datastore.db.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "Not here"},
             retcols=["colA"],
@@ -137,7 +141,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore._simple_select_list(
+        ret = yield self.datastore.db.simple_select_list(
             table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
         )
 
@@ -150,7 +154,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore._simple_update_one(
+        yield self.datastore.db.simple_update_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             updatevalues={"columnname": "New Value"},
@@ -165,7 +169,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore._simple_update_one(
+        yield self.datastore.db.simple_update_one(
             table="tablename",
             keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
             updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -180,7 +184,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_delete_one(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore._simple_delete_one(
+        yield self.datastore.db.simple_delete_one(
             table="tablename", keyvalues={"keycol": "Go away"}
         )
 
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index e9e2d5337c..43425c969a 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -14,7 +14,13 @@
 # limitations under the License.
 
 import os.path
+from unittest.mock import patch
 
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import login, room
 from synapse.storage import prepare_database
 from synapse.types import Requester, UserID
 
@@ -33,17 +39,21 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID("alice", "test")
         self.requester = Requester(self.user, None, False, None, None)
-        info = self.get_success(self.room_creator.create_room(self.requester, {}))
+        info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
 
     def run_background_update(self):
         """Re run the background update to clean up the extremities.
         """
         # Make sure we don't clash with in progress updates.
-        self.assertTrue(self.store._all_done, "Background updates are still ongoing")
+        self.assertTrue(
+            self.store.db.updates._all_done, "Background updates are still ongoing"
+        )
 
         schema_path = os.path.join(
             prepare_database.dir_path,
+            "data_stores",
+            "main",
             "schema",
             "delta",
             "54",
@@ -54,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             prepare_database.executescript(txn, schema_path)
 
         self.get_success(
-            self.store.runInteraction("test_delete_forward_extremities", run_delta_file)
+            self.store.db.runInteraction(
+                "test_delete_forward_extremities", run_delta_file
+            )
         )
 
         # Ugh, have to reset this flag
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
     def test_soft_failed_extremities_handled_correctly(self):
         """Test that extremities are correctly calculated in the presence of
@@ -118,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -156,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
+        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -211,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(
-            set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
-        )
+        self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
 
         # Run the background update and check it did the right thing
         self.run_background_update()
@@ -221,10 +235,18 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
+        self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
 
 
 class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
+    CONSENT_VERSION = "1"
+    EXTREMITIES_COUNT = 50
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
     def make_homeserver(self, reactor, clock):
         config = self.default_config()
         config["cleanup_extremities_with_dummy_events"] = True
@@ -233,28 +255,39 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
         self.store = homeserver.get_datastore()
         self.room_creator = homeserver.get_room_creation_handler()
+        self.event_creator_handler = homeserver.get_event_creation_handler()
 
         # Create a test user and room
-        self.user = UserID("alice", "test")
+        self.user = UserID.from_string(self.register_user("user1", "password"))
+        self.token1 = self.login("user1", "password")
         self.requester = Requester(self.user, None, False, None, None)
-        info = self.get_success(self.room_creator.create_room(self.requester, {}))
+        info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
+        self.event_creator = homeserver.get_event_creation_handler()
+        homeserver.config.user_consent_version = self.CONSENT_VERSION
 
     def test_send_dummy_event(self):
-        # Create a bushy graph with 50 extremities.
+        self._create_extremity_rich_graph()
 
-        event_id_start = self.create_and_send_event(self.room_id, self.user)
-
-        for _ in range(50):
-            self.create_and_send_event(
-                self.room_id, self.user, prev_event_ids=[event_id_start]
-            )
+        # Pump the reactor repeatedly so that the background updates have a
+        # chance to run.
+        self.pump(10 * 60)
 
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
-        self.assertEqual(len(latest_event_ids), 50)
+        self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
 
+    @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
+    def test_send_dummy_events_when_insufficient_power(self):
+        self._create_extremity_rich_graph()
+        # Criple power levels
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.PowerLevels,
+            body={"users": {str(self.user): -1}},
+            tok=self.token1,
+        )
         # Pump the reactor repeatedly so that the background updates have a
         # chance to run.
         self.pump(10 * 60)
@@ -262,4 +295,108 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
         )
+        # Check that the room has not been pruned
+        self.assertTrue(len(latest_event_ids) > 10)
+
+        # New user with regular levels
+        user2 = self.register_user("user2", "password")
+        token2 = self.login("user2", "password")
+        self.helper.join(self.room_id, user2, tok=token2)
+        self.pump(10 * 60)
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
         self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+
+    @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=0)
+    def test_send_dummy_event_without_consent(self):
+        self._create_extremity_rich_graph()
+        self._enable_consent_checking()
+
+        # Pump the reactor repeatedly so that the background updates have a
+        # chance to run. Attempt to add dummy event with user that has not consented
+        # Check that dummy event send fails.
+        self.pump(10 * 60)
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertTrue(len(latest_event_ids) == self.EXTREMITIES_COUNT)
+
+        # Create new user, and add consent
+        user2 = self.register_user("user2", "password")
+        token2 = self.login("user2", "password")
+        self.get_success(
+            self.store.user_set_consent_version(user2, self.CONSENT_VERSION)
+        )
+        self.helper.join(self.room_id, user2, tok=token2)
+
+        # Background updates should now cause a dummy event to be added to the graph
+        self.pump(10 * 60)
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertTrue(len(latest_event_ids) < 10, len(latest_event_ids))
+
+    @patch("synapse.handlers.message._DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY", new=250)
+    def test_expiry_logic(self):
+        """Simple test to ensure that _expire_rooms_to_exclude_from_dummy_event_insertion()
+        expires old entries correctly.
+        """
+        self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+            "1"
+        ] = 100000
+        self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+            "2"
+        ] = 200000
+        self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
+            "3"
+        ] = 300000
+        self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
+        # All entries within time frame
+        self.assertEqual(
+            len(
+                self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+            ),
+            3,
+        )
+        # Oldest room to expire
+        self.pump(1)
+        self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
+        self.assertEqual(
+            len(
+                self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+            ),
+            2,
+        )
+        # All rooms to expire
+        self.pump(2)
+        self.assertEqual(
+            len(
+                self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion
+            ),
+            0,
+        )
+
+    def _create_extremity_rich_graph(self):
+        """Helper method to create bushy graph on demand"""
+
+        event_id_start = self.create_and_send_event(self.room_id, self.user)
+
+        for _ in range(self.EXTREMITIES_COUNT):
+            self.create_and_send_event(
+                self.room_id, self.user, prev_event_ids=[event_id_start]
+            )
+
+        latest_event_ids = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
+        )
+        self.assertEqual(len(latest_event_ids), 50)
+
+    def _enable_consent_checking(self):
+        """Helper method to enable consent checking"""
+        self.event_creator._block_events_without_consent_error = "No consent from user"
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creator._consent_uri_builder = consent_uri_builder
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 09305c3bf1..3b483bc7f0 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -23,6 +23,7 @@ from synapse.http.site import XForwardedForRequest
 from synapse.rest.client.v1 import login
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class ClientIpStoreTestCase(unittest.HomeserverTestCase):
@@ -37,9 +38,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,15 +52,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
-                "access_token": "access_token",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -82,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store._simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -113,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store._simple_select_list(
+            self.store.db.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -134,9 +138,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    @override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
     def test_disabled_monthly_active_user(self):
-        self.hs.config.limit_usage_by_mau = False
-        self.hs.config.max_mau_value = 50
         user_id = "@user:server"
         self.get_success(
             self.store.insert_client_ip(
@@ -146,9 +149,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
 
+    @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
     def test_adding_monthly_active_user_when_full(self):
-        self.hs.config.limit_usage_by_mau = True
-        self.hs.config.max_mau_value = 50
         lots_of_users = 100
         user_id = "@user:server"
 
@@ -163,9 +165,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
 
+    @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
     def test_adding_monthly_active_user_when_space(self):
-        self.hs.config.limit_usage_by_mau = True
-        self.hs.config.max_mau_value = 50
         user_id = "@user:server"
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertFalse(active)
@@ -181,9 +182,8 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertTrue(active)
 
+    @override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
     def test_updating_monthly_active_user_when_space(self):
-        self.hs.config.limit_usage_by_mau = True
-        self.hs.config.max_mau_value = 50
         user_id = "@user:server"
         self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
 
@@ -201,6 +201,173 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
         self.assertTrue(active)
 
+    def test_devices_last_seen_bg_update(self):
+        # First make sure we have completed all updates.
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
+
+        user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", device_id
+            )
+        )
+        # Force persisting to disk
+        self.reactor.advance(200)
+
+        # But clear the associated entry in devices table
+        self.get_success(
+            self.store.db.simple_update(
+                table="devices",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                updatevalues={"last_seen": None, "ip": None, "user_agent": None},
+                desc="test_devices_last_seen_bg_update",
+            )
+        )
+
+        # We should now get nulls when querying
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, device_id)
+        )
+
+        r = result[(user_id, device_id)]
+        self.assertDictContainsSubset(
+            {
+                "user_id": user_id,
+                "device_id": device_id,
+                "ip": None,
+                "user_agent": None,
+                "last_seen": None,
+            },
+            r,
+        )
+
+        # Register the background update to run again.
+        self.get_success(
+            self.store.db.simple_insert(
+                table="background_updates",
+                values={
+                    "update_name": "devices_last_seen",
+                    "progress_json": "{}",
+                    "depends_on": None,
+                },
+            )
+        )
+
+        # ... and tell the DataStore that it hasn't finished all updates yet
+        self.store.db.updates._all_done = False
+
+        # Now let's actually drive the updates to completion
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
+
+        # We should now get the correct result again
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, device_id)
+        )
+
+        r = result[(user_id, device_id)]
+        self.assertDictContainsSubset(
+            {
+                "user_id": user_id,
+                "device_id": device_id,
+                "ip": "ip",
+                "user_agent": "user_agent",
+                "last_seen": 0,
+            },
+            r,
+        )
+
+    def test_old_user_ips_pruned(self):
+        # First make sure we have completed all updates.
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
+
+        user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
+        self.get_success(
+            self.store.insert_client_ip(
+                user_id, "access_token", "ip", "user_agent", device_id
+            )
+        )
+
+        # Force persisting to disk
+        self.reactor.advance(200)
+
+        # We should see that in the DB
+        result = self.get_success(
+            self.store.db.simple_select_list(
+                table="user_ips",
+                keyvalues={"user_id": user_id},
+                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+                desc="get_user_ip_and_agents",
+            )
+        )
+
+        self.assertEqual(
+            result,
+            [
+                {
+                    "access_token": "access_token",
+                    "ip": "ip",
+                    "user_agent": "user_agent",
+                    "device_id": device_id,
+                    "last_seen": 0,
+                }
+            ],
+        )
+
+        # Now advance by a couple of months
+        self.reactor.advance(60 * 24 * 60 * 60)
+
+        # We should get no results.
+        result = self.get_success(
+            self.store.db.simple_select_list(
+                table="user_ips",
+                keyvalues={"user_id": user_id},
+                retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
+                desc="get_user_ip_and_agents",
+            )
+        )
+
+        self.assertEqual(result, [])
+
+        # But we should still get the correct values for the device
+        result = self.get_success(
+            self.store.get_last_client_ip_by_device(user_id, device_id)
+        )
+
+        r = result[(user_id, device_id)]
+        self.assertDictContainsSubset(
+            {
+                "user_id": user_id,
+                "device_id": device_id,
+                "ip": "ip",
+                "user_agent": "user_agent",
+                "last_seen": 0,
+            },
+            r,
+        )
+
 
 class ClientIpAuthTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
new file mode 100644
index 0000000000..5a77c84962
--- /dev/null
+++ b/tests/storage/test_database.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.database import make_tuple_comparison_clause
+from synapse.storage.engines import BaseDatabaseEngine
+
+from tests import unittest
+
+
+def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
+    # returns a DatabaseEngine, circumventing the abc mechanism
+    # any kwargs are set as attributes on the class before instantiating it
+    t = type(
+        "TestBaseDatabaseEngine",
+        (BaseDatabaseEngine,),
+        dict(BaseDatabaseEngine.__dict__),
+    )
+    # defeat the abc mechanism
+    t.__abstractmethods__ = set()
+    for k, v in kwargs.items():
+        setattr(t, k, v)
+    return t(None, None)
+
+
+class TupleComparisonClauseTestCase(unittest.TestCase):
+    def test_native_tuple_comparison(self):
+        db_engine = _stub_db_engine(supports_tuple_comparison=True)
+        clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)])
+        self.assertEqual(clause, "(a,b) > (?,?)")
+        self.assertEqual(args, [1, 2])
+
+    def test_emulated_tuple_comparison(self):
+        db_engine = _stub_db_engine(supports_tuple_comparison=False)
+        clause, args = make_tuple_comparison_clause(
+            db_engine, [("a", 1), ("b", 2), ("c", 3)]
+        )
+        self.assertEqual(
+            clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))"
+        )
+        self.assertEqual(args, [1, 1, 2, 2, 3])
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..c2539b353a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote(self):
+    def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
@@ -81,63 +81,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "somehost", -1, limit=100
         )
 
         # Check original device_ids are contained within these updates
         self._check_devices_in_updates(device_ids, device_updates)
 
-    @defer.inlineCallbacks
-    def test_get_devices_by_remote_limited(self):
-        # Test breaking the update limit in 1, 101, and 1 device_id segments
-
-        # first add one device
-        device_ids1 = ["device_id0"]
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids1, ["someotherhost"]
-        )
-
-        # then add 101
-        device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids2, ["someotherhost"]
-        )
-
-        # then one more
-        device_ids3 = ["newdevice"]
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids3, ["someotherhost"]
-        )
-
-        #
-        # now read them back.
-        #
-
-        # first we should get a single update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
-            "someotherhost", -1, limit=100
-        )
-        self._check_devices_in_updates(device_ids1, device_updates)
-
-        # Then we should get an empty list back as the 101 devices broke the limit
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
-            "someotherhost", now_stream_id, limit=100
-        )
-        self.assertEqual(len(device_updates), 0)
-
-        # The 101 devices should've been cleared, so we should now just get one device
-        # update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
-            "someotherhost", now_stream_id, limit=100
-        )
-        self._check_devices_in_updates(device_ids3, device_updates)
-
     def _check_devices_in_updates(self, expected_device_ids, device_updates):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))
 
-        received_device_ids = {update["device_id"] for update in device_updates}
+        received_device_ids = {
+            update["device_id"] for edu_type, update in device_updates
+        }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py
new file mode 100644
index 0000000000..35dafbb904
--- /dev/null
+++ b/tests/storage/test_e2e_room_keys.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests import unittest
+
+# sample room_key data for use in the tests
+room_key = {
+    "first_message_index": 1,
+    "forwarded_count": 1,
+    "is_verified": False,
+    "session_data": "SSBBTSBBIEZJU0gK",
+}
+
+
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver("server", http_client=None)
+        self.store = hs.get_datastore()
+        return hs
+
+    def test_room_keys_version_delete(self):
+        # test that deleting a room key backup deletes the keys
+        version1 = self.get_success(
+            self.store.create_e2e_room_keys_version(
+                "user_id", {"algorithm": "rot13", "auth_data": {}}
+            )
+        )
+
+        self.get_success(
+            self.store.add_e2e_room_keys(
+                "user_id", version1, [("room", "session", room_key)]
+            )
+        )
+
+        version2 = self.get_success(
+            self.store.create_e2e_room_keys_version(
+                "user_id", {"algorithm": "rot13", "auth_data": {}}
+            )
+        )
+
+        self.get_success(
+            self.store.add_e2e_room_keys(
+                "user_id", version2, [("room", "session", room_key)]
+            )
+        )
+
+        # make sure the keys were stored properly
+        keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+        self.assertEqual(len(keys["rooms"]), 1)
+
+        keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+        self.assertEqual(len(keys["rooms"]), 1)
+
+        # delete version1
+        self.get_success(self.store.delete_e2e_room_keys_version("user_id", version1))
+
+        # make sure the key from version1 is gone, and the key from version2 is
+        # still there
+        keys = self.get_success(self.store.get_e2e_room_keys("user_id", version1))
+        self.assertEqual(len(keys["rooms"]), 0)
+
+        keys = self.get_success(self.store.get_e2e_room_keys("user_id", version2))
+        self.assertEqual(len(keys["rooms"]), 1)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index c8ece15284..398d546280 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
-        self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
+        self.assertDictContainsSubset(json, dev)
 
     @defer.inlineCallbacks
     def test_reupload_key(self):
@@ -68,7 +68,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
         self.assertDictContainsSubset(
-            {"keys": json, "device_display_name": "display_name"}, dev
+            {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev
         )
 
     @defer.inlineCallbacks
@@ -80,10 +80,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.store_device("user2", "device1", None)
         yield self.store.store_device("user2", "device2", None)
 
-        yield self.store.set_e2e_device_keys("user1", "device1", now, "json11")
-        yield self.store.set_e2e_device_keys("user1", "device2", now, "json12")
-        yield self.store.set_e2e_device_keys("user2", "device1", now, "json21")
-        yield self.store.set_e2e_device_keys("user2", "device2", now, "json22")
+        yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+        yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+        yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+        yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
 
         res = yield self.store.get_e2e_device_keys(
             (("user1", "device1"), ("user2", "device2"))
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 86c7ac350d..3aeec0dc0f 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,19 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import tests.unittest
 import tests.utils
 
 
-class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_get_prev_events_for_room(self):
         room_id = "@ROOM:local"
 
@@ -57,21 +52,182 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                     "(event_id, algorithm, hash) "
                     "VALUES (?, 'sha256', ?)"
                 ),
-                (event_id, b"ffff"),
+                (event_id, bytearray(b"ffff")),
             )
 
-        for i in range(0, 11):
-            yield self.store.runInteraction("insert", insert_event, i)
+        for i in range(0, 20):
+            self.get_success(self.store.db.runInteraction("insert", insert_event, i))
 
-        # this should get the last five and five others
-        r = yield self.store.get_prev_events_for_room(room_id)
+        # this should get the last ten
+        r = self.get_success(self.store.get_prev_events_for_room(room_id))
         self.assertEqual(10, len(r))
-        for i in range(0, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertEqual(10 - i, depth)
-
-        for i in range(5, 5):
-            el = r[i]
-            depth = el[2]
-            self.assertLessEqual(5, depth)
+        for i in range(0, 10):
+            self.assertEqual("$event_%i:local" % (19 - i), r[i])
+
+    def test_get_rooms_with_many_extremities(self):
+        room1 = "#room1"
+        room2 = "#room2"
+        room3 = "#room3"
+
+        def insert_event(txn, i, room_id):
+            event_id = "$event_%i:local" % i
+            txn.execute(
+                (
+                    "INSERT INTO event_forward_extremities (room_id, event_id) "
+                    "VALUES (?, ?)"
+                ),
+                (room_id, event_id),
+            )
+
+        for i in range(0, 20):
+            self.get_success(
+                self.store.db.runInteraction("insert", insert_event, i, room1)
+            )
+            self.get_success(
+                self.store.db.runInteraction("insert", insert_event, i, room2)
+            )
+            self.get_success(
+                self.store.db.runInteraction("insert", insert_event, i, room3)
+            )
+
+        # Test simple case
+        r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
+        self.assertEqual(len(r), 3)
+
+        # Does filter work?
+
+        r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
+        self.assertTrue(room2 in r)
+        self.assertTrue(room3 in r)
+        self.assertEqual(len(r), 2)
+
+        r = self.get_success(
+            self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+        )
+        self.assertEqual(r, [room3])
+
+        # Does filter and limit work?
+
+        r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
+        self.assertTrue(r == [room2] or r == [room3])
+
+    def test_auth_difference(self):
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ´ |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
+
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
+
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
+
+        def insert_event(txn, event_id, stream_ordering):
+
+            depth = depth_map[event_id]
+
+            self.store.db.simple_insert_txn(
+                txn,
+                table="events",
+                values={
+                    "event_id": event_id,
+                    "room_id": room_id,
+                    "depth": depth,
+                    "topological_ordering": depth,
+                    "type": "m.test",
+                    "processed": True,
+                    "outlier": False,
+                    "stream_ordering": stream_ordering,
+                },
+            )
+
+            self.store.db.simple_insert_many_txn(
+                txn,
+                table="event_auth",
+                values=[
+                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
+                    for a in auth_graph[event_id]
+                ],
+            )
+
+        next_stream_ordering = 0
+        for event_id in auth_graph:
+            next_stream_ordering += 1
+            self.get_success(
+                self.store.db.runInteraction(
+                    "insert", insert_event, event_id, next_stream_ordering
+                )
+            )
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+        self.assertSetEqual(difference, set())
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index f26ff57a18..a7b85004e5 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
         events = [(3, 2), (6, 2), (4, 6)]
 
         for event_count, extrems in events:
-            info = self.get_success(room_creator.create_room(requester, {}))
+            info, _ = self.get_success(room_creator.create_room(requester, {}))
             room_id = info["room_id"]
 
             last_event = None
@@ -59,24 +59,22 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
             )
         )
 
-        expected = set(
-            [
-                b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
-                b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
-                b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
-                b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
-                b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
-                b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
-                b"synapse_forward_extremities_count 3.0",
-                b"synapse_forward_extremities_sum 10.0",
-            ]
-        )
+        expected = {
+            b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
+            b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
+            b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
+            b'synapse_forward_extremities_bucket{le="5.0"} 2.0',
+            b'synapse_forward_extremities_bucket{le="7.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="10.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="15.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="20.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="50.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
+            b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
+            b"synapse_forward_extremities_count 3.0",
+            b"synapse_forward_extremities_sum 10.0",
+        }
 
         self.assertEqual(items, expected)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b114c6fb1d..b45bc9c115 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -35,6 +35,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     def setUp(self):
         hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
         self.store = hs.get_datastore()
+        self.persist_events_store = hs.get_datastores().persist_events
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
@@ -55,7 +56,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.runInteraction(
+            counts = yield self.store.db.runInteraction(
                 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
             )
             self.assertEquals(
@@ -74,20 +75,20 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             yield self.store.add_push_actions_to_staging(
                 event.event_id, {user_id: action}
             )
-            yield self.store.runInteraction(
+            yield self.store.db.runInteraction(
                 "",
-                self.store._set_push_actions_for_event_and_users_txn,
+                self.persist_events_store._set_push_actions_for_event_and_users_txn,
                 [(event, None)],
                 [(event, None)],
             )
 
         def _rotate(stream):
-            return self.store.runInteraction(
+            return self.store.db.runInteraction(
                 "", self.store._rotate_notifs_before_txn, stream
             )
 
         def _mark_read(stream, depth):
-            return self.store.runInteraction(
+            return self.store.db.runInteraction(
                 "",
                 self.store._remove_old_push_actions_before_txn,
                 room_id,
@@ -116,7 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _inject_actions(6, PlAIN_NOTIF)
         yield _rotate(7)
 
-        yield self.store._simple_delete(
+        yield self.store.db.simple_delete(
             table="event_push_actions", keyvalues={"1": 1}, desc=""
         )
 
@@ -135,7 +136,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store._simple_insert(
+            return self.store.db.simple_insert(
                 "events",
                 {
                     "stream_ordering": so,
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db = self.store.db  # type: Database
+
+        self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db,
+                instance_name=instance_name,
+                table="foobar",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="foobar_seq",
+            )
+
+        return self.get_success(self.db.runWithConnection(_create))
+
+    def _insert_rows(self, instance_name: str, number: int):
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+                    (instance_name,),
+                )
+
+        self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+    def test_empty(self):
+        """Test an ID generator against an empty database gives sensible
+        current positions.
+        """
+
+        id_gen = self._create_id_generator()
+
+        # The table is empty so we expect an empty map for positions
+        self.assertEqual(id_gen.get_positions(), {})
+
+    def test_single_instance(self):
+        """Test that reads and writes from a single process are handled
+        correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async():
+            with await id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(id_gen.get_positions(), {"master": 7})
+                self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token("master"), 8)
+
+    def test_multi_instance(self):
+        """Test that reads and writes from multiple processes are handled
+        correctly.
+        """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator("first")
+        second_id_gen = self._create_id_generator("second")
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async():
+            with await first_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+        # However the ID gen on the second instance won't have seen the update
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+        # ... but calling `get_next` on the second instance should give a unique
+        # stream ID
+
+        async def _get_next_async():
+            with await second_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 9)
+
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+        # If the second ID gen gets told about the first, it correctly updates
+        second_id_gen.advance("first", 8)
+        self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+    def test_get_next_txn(self):
+        """Test that the `get_next_txn` function works correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        def _get_next_txn(txn):
+            stream_id = id_gen.get_next_txn(txn)
+            self.assertEqual(stream_id, 8)
+
+            self.assertEqual(id_gen.get_positions(), {"master": 7})
+            self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index e07ff01201..95f309fbbc 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import signedjson.key
+import unpaddedbase64
 
 from twisted.internet.defer import Deferred
 
@@ -21,11 +22,17 @@ from synapse.storage.keys import FetchKeyResult
 
 import tests.unittest
 
-KEY_1 = signedjson.key.decode_verify_key_base64(
-    "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
+
+def decode_verify_key_base64(key_id: str, key_base64: str):
+    key_bytes = unpaddedbase64.decode_base64(key_base64)
+    return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
+
+
+KEY_1 = decode_verify_key_base64(
+    "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
 )
-KEY_2 = signedjson.key.decode_verify_key_base64(
-    "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
+KEY_2 = decode_verify_key_base64(
+    "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
 )
 
 
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
new file mode 100644
index 0000000000..ab0df5ea93
--- /dev/null
+++ b/tests/storage/test_main.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Awesome Technologies Innovationslabor GmbH
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.types import UserID
+
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+
+class DataStoreTestCase(unittest.TestCase):
+    @defer.inlineCallbacks
+    def setUp(self):
+        hs = yield setup_test_homeserver(self.addCleanup)
+
+        self.store = hs.get_datastore()
+
+        self.user = UserID.from_string("@abcde:test")
+        self.displayname = "Frank"
+
+    @defer.inlineCallbacks
+    def test_get_users_paginate(self):
+        yield self.store.register_user(self.user.to_string(), "pass")
+        yield self.store.create_profile(self.user.localpart)
+        yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+
+        users, total = yield self.store.get_users_paginate(
+            0, 10, name="bc", guests=False
+        )
+
+        self.assertEquals(1, total)
+        self.assertEquals(self.displayname, users.pop()["displayname"])
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 1494650d10..9c04e92577 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,152 +19,222 @@ from twisted.internet import defer
 from synapse.api.constants import UserTypes
 
 from tests import unittest
+from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
 
 
+def gen_3pids(count):
+    """Generate `count` threepids as a list."""
+    return [
+        {"medium": "email", "address": "user%i@matrix.org" % i} for i in range(count)
+    ]
+
+
 class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def default_config(self):
+        config = default_config("test")
+
+        config.update({"limit_usage_by_mau": True, "max_mau_value": 50})
+
+        # apply any additional config which was specified via the override_config
+        # decorator.
+        if self._extra_config is not None:
+            config.update(self._extra_config)
 
-        hs = self.setup_test_homeserver()
-        self.store = hs.get_datastore()
-        hs.config.limit_usage_by_mau = True
-        hs.config.max_mau_value = 50
+        return config
 
+    def prepare(self, reactor, clock, homeserver):
+        self.store = homeserver.get_datastore()
         # Advance the clock a bit
         reactor.advance(FORTY_DAYS)
 
-        return hs
-
+    @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)})
     def test_initialise_reserved_users(self):
-        self.hs.config.max_mau_value = 5
+        threepids = self.hs.config.mau_limits_reserved_threepids
+
+        # register three users, of which two have reserved 3pids, and a third
+        # which is a support user.
         user1 = "@user1:server"
-        user1_email = "user1@matrix.org"
+        user1_email = threepids[0]["address"]
         user2 = "@user2:server"
-        user2_email = "user2@matrix.org"
+        user2_email = threepids[1]["address"]
         user3 = "@user3:server"
-        user3_email = "user3@matrix.org"
-
-        threepids = [
-            {"medium": "email", "address": user1_email},
-            {"medium": "email", "address": user2_email},
-            {"medium": "email", "address": user3_email},
-        ]
-        # -1 because user3 is a support user and does not count
-        user_num = len(threepids) - 1
 
-        self.store.register_user(user_id=user1, password_hash=None)
-        self.store.register_user(user_id=user2, password_hash=None)
-        self.store.register_user(
-            user_id=user3, password_hash=None, user_type=UserTypes.SUPPORT
+        self.get_success(self.store.register_user(user_id=user1))
+        self.get_success(self.store.register_user(user_id=user2))
+        self.get_success(
+            self.store.register_user(user_id=user3, user_type=UserTypes.SUPPORT)
         )
-        self.pump()
 
         now = int(self.hs.get_clock().time_msec())
-        self.store.user_add_threepid(user1, "email", user1_email, now, now)
-        self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        self.get_success(
+            self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        )
+        self.get_success(
+            self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        )
 
-        self.store.runInteraction(
-            "initialise", self.store._initialise_reserved_users, threepids
+        # XXX why are we doing this here? this function is only run at startup
+        # so it is odd to re-run it here.
+        self.get_success(
+            self.store.db.runInteraction(
+                "initialise", self.store._initialise_reserved_users, threepids
+            )
         )
-        self.pump()
 
-        active_count = self.store.get_monthly_active_count()
+        # the number of users we expect will be counted against the mau limit
+        # -1 because user3 is a support user and does not count
+        user_num = len(threepids) - 1
 
-        # Test total counts, ensure user3 (support user) is not counted
-        self.assertEquals(self.get_success(active_count), user_num)
+        # Check the number of active users. Ensure user3 (support user) is not counted
+        active_count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(active_count, user_num)
 
-        # Test user is marked as active
-        timestamp = self.store.user_last_seen_monthly_active(user1)
-        self.assertTrue(self.get_success(timestamp))
-        timestamp = self.store.user_last_seen_monthly_active(user2)
-        self.assertTrue(self.get_success(timestamp))
+        # Test each of the registered users is marked as active
+        timestamp = self.get_success(self.store.user_last_seen_monthly_active(user1))
+        self.assertGreater(timestamp, 0)
+        timestamp = self.get_success(self.store.user_last_seen_monthly_active(user2))
+        self.assertGreater(timestamp, 0)
 
-        # Test that users are never removed from the db.
+        # Test that users with reserved 3pids are not removed from the MAU table
+        # XXX some of this is redundant. poking things into the config shouldn't
+        # work, and in any case it's not obvious what we expect to happen when
+        # we advance the reactor.
         self.hs.config.max_mau_value = 0
-
         self.reactor.advance(FORTY_DAYS)
+        self.hs.config.max_mau_value = 5
 
-        self.store.reap_monthly_active_users()
-        self.pump()
+        self.get_success(self.store.reap_monthly_active_users())
 
-        active_count = self.store.get_monthly_active_count()
-        self.assertEquals(self.get_success(active_count), user_num)
+        active_count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(active_count, user_num)
 
-        # Test that regular users are removed from the db
+        # Add some more users and check they are counted as active
         ru_count = 2
-        self.store.upsert_monthly_active_user("@ru1:server")
-        self.store.upsert_monthly_active_user("@ru2:server")
-        self.pump()
 
-        active_count = self.store.get_monthly_active_count()
-        self.assertEqual(self.get_success(active_count), user_num + ru_count)
-        self.hs.config.max_mau_value = user_num
-        self.store.reap_monthly_active_users()
-        self.pump()
+        self.get_success(self.store.upsert_monthly_active_user("@ru1:server"))
+        self.get_success(self.store.upsert_monthly_active_user("@ru2:server"))
+
+        active_count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(active_count, user_num + ru_count)
 
-        active_count = self.store.get_monthly_active_count()
-        self.assertEquals(self.get_success(active_count), user_num)
+        # now run the reaper and check that the number of active users is reduced
+        # to max_mau_value
+        self.get_success(self.store.reap_monthly_active_users())
+
+        active_count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(active_count, 3)
 
     def test_can_insert_and_count_mau(self):
-        count = self.store.get_monthly_active_count()
-        self.assertEqual(0, self.get_success(count))
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 0)
 
-        self.store.upsert_monthly_active_user("@user:server")
-        self.pump()
+        d = self.store.upsert_monthly_active_user("@user:server")
+        self.get_success(d)
 
-        count = self.store.get_monthly_active_count()
-        self.assertEqual(1, self.get_success(count))
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 1)
 
     def test_user_last_seen_monthly_active(self):
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
         user_id3 = "@user3:server"
 
-        result = self.store.user_last_seen_monthly_active(user_id1)
-        self.assertFalse(self.get_success(result) == 0)
+        result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+        self.assertNotEqual(result, 0)
 
-        self.store.upsert_monthly_active_user(user_id1)
-        self.store.upsert_monthly_active_user(user_id2)
-        self.pump()
+        self.get_success(self.store.upsert_monthly_active_user(user_id1))
+        self.get_success(self.store.upsert_monthly_active_user(user_id2))
 
-        result = self.store.user_last_seen_monthly_active(user_id1)
-        self.assertGreater(self.get_success(result), 0)
+        result = self.get_success(self.store.user_last_seen_monthly_active(user_id1))
+        self.assertGreater(result, 0)
 
-        result = self.store.user_last_seen_monthly_active(user_id3)
-        self.assertNotEqual(self.get_success(result), 0)
+        result = self.get_success(self.store.user_last_seen_monthly_active(user_id3))
+        self.assertNotEqual(result, 0)
 
+    @override_config({"max_mau_value": 5})
     def test_reap_monthly_active_users(self):
-        self.hs.config.max_mau_value = 5
         initial_users = 10
         for i in range(initial_users):
-            self.store.upsert_monthly_active_user("@user%d:server" % i)
-        self.pump()
+            self.get_success(
+                self.store.upsert_monthly_active_user("@user%d:server" % i)
+            )
 
-        count = self.store.get_monthly_active_count()
-        self.assertTrue(self.get_success(count), initial_users)
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, initial_users)
 
-        self.store.reap_monthly_active_users()
-        self.pump()
-        count = self.store.get_monthly_active_count()
-        self.assertEquals(
-            self.get_success(count), initial_users - self.hs.config.max_mau_value
-        )
+        d = self.store.reap_monthly_active_users()
+        self.get_success(d)
+
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, self.hs.config.max_mau_value)
 
         self.reactor.advance(FORTY_DAYS)
-        self.store.reap_monthly_active_users()
-        self.pump()
 
-        count = self.store.get_monthly_active_count()
-        self.assertEquals(self.get_success(count), 0)
+        d = self.store.reap_monthly_active_users()
+        self.get_success(d)
+
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 0)
+
+    # Note that below says mau_limit (no s), this is the name of the config
+    # value, although it gets stored on the config object as mau_limits.
+    @override_config({"max_mau_value": 5, "mau_limit_reserved_threepids": gen_3pids(5)})
+    def test_reap_monthly_active_users_reserved_users(self):
+        """ Tests that reaping correctly handles reaping where reserved users are
+        present"""
+        threepids = self.hs.config.mau_limits_reserved_threepids
+        initial_users = len(threepids)
+        reserved_user_number = initial_users - 1
+        for i in range(initial_users):
+            user = "@user%d:server" % i
+            email = "user%d@matrix.org" % i
+
+            self.get_success(self.store.upsert_monthly_active_user(user))
+
+            # Need to ensure that the most recent entries in the
+            # monthly_active_users table are reserved
+            now = int(self.hs.get_clock().time_msec())
+            if i != 0:
+                self.get_success(
+                    self.store.register_user(user_id=user, password_hash=None)
+                )
+                self.get_success(
+                    self.store.user_add_threepid(user, "email", email, now, now)
+                )
+
+        d = self.store.db.runInteraction(
+            "initialise", self.store._initialise_reserved_users, threepids
+        )
+        self.get_success(d)
+
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, initial_users)
+
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEqual(len(users), reserved_user_number)
+
+        d = self.store.reap_monthly_active_users()
+        self.get_success(d)
+
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, self.hs.config.max_mau_value)
 
     def test_populate_monthly_users_is_guest(self):
         # Test that guest users are not added to mau list
         user_id = "@user_id:host"
-        self.store.register_user(user_id=user_id, password_hash=None, make_guest=True)
+
+        d = self.store.register_user(
+            user_id=user_id, password_hash=None, make_guest=True
+        )
+        self.get_success(d)
+
         self.store.upsert_monthly_active_user = Mock()
-        self.store.populate_monthly_active_users(user_id)
-        self.pump()
+
+        d = self.store.populate_monthly_active_users(user_id)
+        self.get_success(d)
+
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
@@ -175,8 +245,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(None)
         )
-        self.store.populate_monthly_active_users("user_id")
-        self.pump()
+        d = self.store.populate_monthly_active_users("user_id")
+        self.get_success(d)
+
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
@@ -186,80 +257,132 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.user_last_seen_monthly_active = Mock(
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
-        self.store.populate_monthly_active_users("user_id")
-        self.pump()
+
+        d = self.store.populate_monthly_active_users("user_id")
+        self.get_success(d)
+
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_get_reserved_real_user_account(self):
         # Test no reserved users, or reserved threepids
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), 0)
-        # Test reserved users but no registered users
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEqual(len(users), 0)
 
+        # Test reserved users but no registered users
         user1 = "@user1:example.com"
         user2 = "@user2:example.com"
+
         user1_email = "user1@example.com"
         user2_email = "user2@example.com"
         threepids = [
             {"medium": "email", "address": user1_email},
             {"medium": "email", "address": user2_email},
         ]
+
         self.hs.config.mau_limits_reserved_threepids = threepids
-        self.store.runInteraction(
+        d = self.store.db.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
+        self.get_success(d)
 
-        self.pump()
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), 0)
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEqual(len(users), 0)
 
-        # Test reserved registed users
-        self.store.register_user(user_id=user1, password_hash=None)
-        self.store.register_user(user_id=user2, password_hash=None)
-        self.pump()
+        # Test reserved registered users
+        self.get_success(self.store.register_user(user_id=user1, password_hash=None))
+        self.get_success(self.store.register_user(user_id=user2, password_hash=None))
 
         now = int(self.hs.get_clock().time_msec())
         self.store.user_add_threepid(user1, "email", user1_email, now, now)
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
-        count = self.store.get_registered_reserved_users_count()
-        self.assertEquals(self.get_success(count), len(threepids))
+
+        users = self.get_success(self.store.get_registered_reserved_users())
+        self.assertEqual(len(users), len(threepids))
 
     def test_support_user_not_add_to_mau_limits(self):
         support_user_id = "@support:test"
-        count = self.store.get_monthly_active_count()
-        self.pump()
-        self.assertEqual(self.get_success(count), 0)
 
-        self.store.register_user(
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 0)
+
+        d = self.store.register_user(
             user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
         )
+        self.get_success(d)
 
-        self.store.upsert_monthly_active_user(support_user_id)
-        count = self.store.get_monthly_active_count()
-        self.pump()
-        self.assertEqual(self.get_success(count), 0)
+        d = self.store.upsert_monthly_active_user(support_user_id)
+        self.get_success(d)
 
-    def test_track_monthly_users_without_cap(self):
-        self.hs.config.limit_usage_by_mau = False
-        self.hs.config.mau_stats_only = True
-        self.hs.config.max_mau_value = 1  # should not matter
+        d = self.store.get_monthly_active_count()
+        count = self.get_success(d)
+        self.assertEqual(count, 0)
 
-        count = self.store.get_monthly_active_count()
-        self.assertEqual(0, self.get_success(count))
+    # Note that the max_mau_value setting should not matter.
+    @override_config(
+        {"limit_usage_by_mau": False, "mau_stats_only": True, "max_mau_value": 1}
+    )
+    def test_track_monthly_users_without_cap(self):
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(0, count)
 
-        self.store.upsert_monthly_active_user("@user1:server")
-        self.store.upsert_monthly_active_user("@user2:server")
-        self.pump()
+        self.get_success(self.store.upsert_monthly_active_user("@user1:server"))
+        self.get_success(self.store.upsert_monthly_active_user("@user2:server"))
 
-        count = self.store.get_monthly_active_count()
-        self.assertEqual(2, self.get_success(count))
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(2, count)
 
+    @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.hs.config.limit_usage_by_mau = False
-        self.hs.config.mau_stats_only = False
         self.store.upsert_monthly_active_user = Mock()
 
-        self.store.populate_monthly_active_users("@user:sever")
-        self.pump()
+        self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
         self.store.upsert_monthly_active_user.assert_not_called()
+
+    def test_get_monthly_active_count_by_service(self):
+        appservice1_user1 = "@appservice1_user1:example.com"
+        appservice1_user2 = "@appservice1_user2:example.com"
+
+        appservice2_user1 = "@appservice2_user1:example.com"
+        native_user1 = "@native_user1:example.com"
+
+        service1 = "service1"
+        service2 = "service2"
+        native = "native"
+
+        self.get_success(
+            self.store.register_user(
+                user_id=appservice1_user1, password_hash=None, appservice_id=service1
+            )
+        )
+        self.get_success(
+            self.store.register_user(
+                user_id=appservice1_user2, password_hash=None, appservice_id=service1
+            )
+        )
+        self.get_success(
+            self.store.register_user(
+                user_id=appservice2_user1, password_hash=None, appservice_id=service2
+            )
+        )
+        self.get_success(
+            self.store.register_user(user_id=native_user1, password_hash=None)
+        )
+
+        count = self.get_success(self.store.get_monthly_active_count_by_service())
+        self.assertEqual(count, {})
+
+        self.get_success(self.store.upsert_monthly_active_user(native_user1))
+        self.get_success(self.store.upsert_monthly_active_user(appservice1_user1))
+        self.get_success(self.store.upsert_monthly_active_user(appservice1_user2))
+        self.get_success(self.store.upsert_monthly_active_user(appservice2_user1))
+
+        count = self.get_success(self.store.get_monthly_active_count())
+        self.assertEqual(count, 4)
+
+        d = self.store.get_monthly_active_count_by_service()
+        result = self.get_success(d)
+
+        self.assertEqual(result[service1], 2)
+        self.assertEqual(result[service2], 1)
+        self.assertEqual(result[native], 1)
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 45824bd3b2..9b6f7211ae 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.profile import ProfileStore
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = ProfileStore(hs.get_db_conn(), hs)
+        self.store = hs.get_datastore()
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index f671599cb8..b9fafaa1a6 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -40,23 +40,24 @@ class PurgeTests(HomeserverTestCase):
         third = self.helper.send(self.room_id, body="test3")
         last = self.helper.send(self.room_id, body="test4")
 
-        storage = self.hs.get_datastore()
+        store = self.hs.get_datastore()
+        storage = self.hs.get_storage()
 
         # Get the topological token
-        event = storage.get_topological_token_for_event(last["event_id"])
+        event = store.get_topological_token_for_event(last["event_id"])
         self.pump()
         event = self.successResultOf(event)
 
         # Purge everything before this topological token
-        purge = storage.purge_history(self.room_id, event, True)
+        purge = storage.purge_events.purge_history(self.room_id, event, True)
         self.pump()
         self.assertEqual(self.successResultOf(purge), None)
 
         # Try and get the events
-        get_first = storage.get_event(first["event_id"])
-        get_second = storage.get_event(second["event_id"])
-        get_third = storage.get_event(third["event_id"])
-        get_last = storage.get_event(last["event_id"])
+        get_first = store.get_event(first["event_id"])
+        get_second = store.get_event(second["event_id"])
+        get_third = store.get_event(third["event_id"])
+        get_last = store.get_event(last["event_id"])
         self.pump()
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index deecfad9fb..db3667dc43 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -116,7 +117,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
+
+        return event
 
     def test_redact(self):
         self.get_success(
@@ -235,8 +238,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             @defer.inlineCallbacks
             def build(self, prev_event_ids):
                 built_event = yield self._base_builder.build(prev_event_ids)
-                built_event.event_id = self._event_id
-                built_event._event_dict["event_id"] = self._event_id
+
+                built_event._event_id = self._event_id
+                built_event._dict["event_id"] = self._event_id
+                assert built_event.event_id == self._event_id
+
                 return built_event
 
             @property
@@ -261,7 +267,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.get_success(self.store.persist_event(event_1, context_1))
+        self.get_success(self.storage.persistence.persist_event(event_1, context_1))
 
         event_2, context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
@@ -280,7 +286,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
-        self.get_success(self.store.persist_event(event_2, context_2))
+        self.get_success(self.storage.persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
         fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -335,7 +341,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         event_json = self.get_success(
-            self.store._simple_select_one_onecol(
+            self.store.db.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -353,7 +359,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(60 * 60 * 2)
 
         event_json = self.get_success(
-            self.store._simple_select_one_onecol(
+            self.store.db.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -361,3 +367,72 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         self.assert_dict({"content": {}}, json.loads(event_json))
+
+    def test_redact_redaction(self):
+        """Tests that we can redact a redaction and can fetch it again.
+        """
+
+        self.get_success(
+            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+        )
+
+        msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t"))
+
+        first_redact_event = self.get_success(
+            self.inject_redaction(
+                self.room1, msg_event.event_id, self.u_alice, "Redacting message"
+            )
+        )
+
+        self.get_success(
+            self.inject_redaction(
+                self.room1,
+                first_redact_event.event_id,
+                self.u_alice,
+                "Redacting redaction",
+            )
+        )
+
+        # Now lets jump to the future where we have censored the redaction event
+        # in the DB.
+        self.reactor.advance(60 * 60 * 24 * 31)
+
+        # We just want to check that fetching the event doesn't raise an exception.
+        self.get_success(
+            self.store.get_event(first_redact_event.event_id, allow_none=True)
+        )
+
+    def test_store_redacted_redaction(self):
+        """Tests that we can store a redacted redaction.
+        """
+
+        self.get_success(
+            self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
+        )
+
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": EventTypes.Redaction,
+                "sender": self.u_alice.to_string(),
+                "room_id": self.room1.to_string(),
+                "content": {"reason": "foo"},
+            },
+        )
+
+        redaction_event, context = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+
+        self.get_success(
+            self.storage.persistence.persist_event(redaction_event, context)
+        )
+
+        # Now lets jump to the future where we have censored the redaction event
+        # in the DB.
+        self.reactor.advance(60 * 60 * 24 * 31)
+
+        # We just want to check that fetching the event doesn't raise an exception.
+        self.get_success(
+            self.store.get_event(redaction_event.event_id, allow_none=True)
+        )
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..71a40a0a49 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
-        self.db_pool = hs.get_db_pool()
 
         self.store = hs.get_datastore()
 
@@ -44,12 +43,14 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 # TODO(paul): Surely this field should be 'user_id', not 'name'
                 "name": self.user_id,
                 "password_hash": self.pwhash,
+                "admin": 0,
                 "is_guest": 0,
                 "consent_version": None,
                 "consent_server_notice_sent": None,
                 "appservice_id": None,
                 "creation_ts": 1000,
                 "user_type": None,
+                "deactivated": 0,
             },
             (yield self.store.get_user_by_id(self.user_id)),
         )
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee45706f..3b78d48896 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -17,6 +17,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomAlias, RoomID, UserID
 
 from tests import unittest
@@ -40,6 +41,7 @@ class RoomStoreTestCase(unittest.TestCase):
             self.room.to_string(),
             room_creator_user_id=self.u_creator.to_string(),
             is_public=True,
+            room_version=RoomVersions.V1,
         )
 
     @defer.inlineCallbacks
@@ -53,6 +55,17 @@ class RoomStoreTestCase(unittest.TestCase):
             (yield self.store.get_room(self.room.to_string())),
         )
 
+    @defer.inlineCallbacks
+    def test_get_room_with_stats(self):
+        self.assertDictContainsSubset(
+            {
+                "room_id": self.room.to_string(),
+                "creator": self.u_creator.to_string(),
+                "public": True,
+            },
+            (yield self.store.get_room_with_stats(self.room.to_string())),
+        )
+
 
 class RoomEventsStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
@@ -62,17 +75,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
 
         yield self.store.store_room(
-            self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
+            self.room.to_string(),
+            room_creator_user_id="@creator:text",
+            is_public=True,
+            room_version=RoomVersions.V1,
         )
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.store.persist_event(
+        yield self.storage.persistence.persist_event(
             self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
         )
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6ffb..5dd46005e6 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -16,13 +16,14 @@
 
 from unittest.mock import Mock
 
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import RoomVersions
+from synapse.api.constants import Membership
 from synapse.rest.admin import register_servlets_for_client_rest_resource
 from synapse.rest.client.v1 import login, room
 from synapse.types import Requester, UserID
 
 from tests import unittest
+from tests.test_utils import event_injection
+from tests.utils import TestHomeServer
 
 
 class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -39,13 +40,11 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         )
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor, clock, hs: TestHomeServer):
 
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
         self.store = hs.get_datastore()
-        self.event_builder_factory = hs.get_event_builder_factory()
-        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.u_alice = self.register_user("alice", "pass")
         self.t_alice = self.login("alice", "pass")
@@ -54,33 +53,13 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # User elsewhere on another host
         self.u_charlie = UserID.from_string("@charlie:elsewhere")
 
-    def inject_room_member(self, room, user, membership, replaces_state=None):
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": EventTypes.Member,
-                "sender": user,
-                "state_key": user,
-                "room_id": room,
-                "content": {"membership": membership},
-            },
-        )
-
-        event, context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-
-        self.get_success(self.store.persist_event(event, context))
-
-        return event
-
     def test_one_member(self):
 
         # Alice creates the room, and is automatically joined
         self.room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
 
         rooms_for_user = self.get_success(
-            self.store.get_rooms_for_user_where_membership_is(
+            self.store.get_rooms_for_local_user_where_membership_is(
                 self.u_alice, [Membership.JOIN]
             )
         )
@@ -137,6 +116,52 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # It now knows about Charlie's server.
         self.assertEqual(self.store._known_servers_count, 2)
 
+    def test_get_joined_users_from_context(self):
+        room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
+        bob_event = event_injection.inject_member_event(
+            self.hs, room, self.u_bob, Membership.JOIN
+        )
+
+        # first, create a regular event
+        event, context = event_injection.create_event(
+            self.hs,
+            room_id=room,
+            sender=self.u_alice,
+            prev_event_ids=[bob_event.event_id],
+            type="m.test.1",
+            content={},
+        )
+
+        users = self.get_success(
+            self.store.get_joined_users_from_context(event, context)
+        )
+        self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+
+        # Regression test for #7376: create a state event whose key matches bob's
+        # user_id, but which is *not* a membership event, and persist that; then check
+        # that `get_joined_users_from_context` returns the correct users for the next event.
+        non_member_event = event_injection.inject_event(
+            self.hs,
+            room_id=room,
+            sender=self.u_bob,
+            prev_event_ids=[bob_event.event_id],
+            type="m.test.2",
+            state_key=self.u_bob,
+            content={},
+        )
+        event, context = event_injection.create_event(
+            self.hs,
+            room_id=room,
+            sender=self.u_alice,
+            prev_event_ids=[non_member_event.event_id],
+            type="m.test.3",
+            content={},
+        )
+        users = self.get_success(
+            self.store.get_joined_users_from_context(event, context)
+        )
+        self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
+
 
 class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
@@ -145,8 +170,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
     def test_can_rerun_update(self):
         # First make sure we have completed all updates.
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
 
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
@@ -155,7 +184,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store._simple_insert(
+            self.store.db.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "current_state_events_membership",
@@ -166,8 +195,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store._all_done = False
+        self.store.db.updates._all_done = False
 
         # Now let's actually drive the updates to completion
-        while not self.get_success(self.store.has_completed_background_updates()):
-            self.get_success(self.store.do_next_background_update(100), by=0.1)
+        while not self.get_success(
+            self.store.db.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db.updates.do_next_background_update(100), by=0.1
+            )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2db..0b88308ff4 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_datastore = self.storage.state.stores.state
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -43,7 +45,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.room = RoomID.from_string("!abc123:test")
 
         yield self.store.store_room(
-            self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
+            self.room.to_string(),
+            room_creator_user_id="@creator:text",
+            is_public=True,
+            room_version=RoomVersions.V1,
         )
 
     @defer.inlineCallbacks
@@ -63,7 +68,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.store.persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
 
         return event
 
@@ -82,7 +87,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups_ids(
+        state_group_map = yield self.storage.state.get_state_groups_ids(
             self.room, [e2.event_id]
         )
         self.assertEqual(len(state_group_map), 1)
@@ -101,7 +106,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+        state_group_map = yield self.storage.state.get_state_groups(
+            self.room, [e2.event_id]
+        )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
 
@@ -141,7 +148,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.store.get_state_for_event(e5.event_id)
+        state = yield self.storage.state.get_state_for_event(e5.event_id)
 
         self.assertIsNotNone(e4)
 
@@ -157,21 +164,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
         )
 
@@ -181,7 +188,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id,
             state_filter=StateFilter(
                 types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -199,7 +206,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -215,13 +222,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+        group_ids = yield self.storage.state.get_state_groups_ids(
+            room_id, [e5.event_id]
+        )
         group = list(group_ids.keys())[0]
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -237,8 +249,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -250,8 +265,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -267,8 +285,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -287,8 +308,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -304,8 +328,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -317,8 +344,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -331,9 +361,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
         # deliberately remove e2 (room name) from the _state_group_cache
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-            group
-        )
+        (
+            is_all,
+            known_absent,
+            state_dict_ids,
+        ) = self.state_datastore._state_group_cache.get(group)
 
         self.assertEqual(is_all, True)
         self.assertEqual(known_absent, set())
@@ -346,21 +378,23 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         state_dict_ids.pop((e2.type, e2.state_key))
-        self.store._state_group_cache.invalidate(group)
-        self.store._state_group_cache.update(
-            sequence=self.store._state_group_cache.sequence,
+        self.state_datastore._state_group_cache.invalidate(group)
+        self.state_datastore._state_group_cache.update(
+            sequence=self.state_datastore._state_group_cache.sequence,
             key=group,
             value=state_dict_ids,
             # list fetched keys so it knows it's partial
             fetched_keys=((e1.type, e1.state_key),),
         )
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-            group
-        )
+        (
+            is_all,
+            known_absent,
+            state_dict_ids,
+        ) = self.state_datastore._state_group_cache.get(group)
 
         self.assertEqual(is_all, False)
-        self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
+        self.assertEqual(known_absent, {(e1.type, e1.state_key)})
         self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
 
         ############################################
@@ -369,8 +403,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -381,8 +418,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -394,8 +434,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -405,8 +448,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -424,8 +470,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -435,8 +484,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -448,8 +500,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -459,8 +514,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py
index a771d5af29..8e817e2c7f 100644
--- a/tests/storage/test_transactions.py
+++ b/tests/storage/test_transactions.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.util.retryutils import MAX_RETRY_INTERVAL
+
 from tests.unittest import HomeserverTestCase
 
 
@@ -45,3 +47,12 @@ class TransactionStoreTestCase(HomeserverTestCase):
         """
         d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100)
         self.get_success(d)
+
+    def test_large_destination_retry(self):
+        d = self.store.set_destination_retry_timings(
+            "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL
+        )
+        self.get_success(d)
+
+        d = self.store.get_destination_retry_timings("example.com")
+        self.get_success(d)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index d7d244ce97..6a545d2eb0 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,8 +15,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage import UserDirectoryStore
-
 from tests import unittest
 from tests.utils import setup_test_homeserver
 
@@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
+        self.store = self.hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.