summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_e2e_keys.py8
-rw-r--r--tests/handlers/test_typing.py39
-rw-r--r--tests/push/test_email.py3
-rw-r--r--tests/push/test_http.py4
-rw-r--r--tests/replication/slave/storage/_base.py6
-rw-r--r--tests/rest/client/v1/test_profile.py2
-rw-r--r--tests/rest/client/v2_alpha/test_register.py3
-rw-r--r--tests/server.py55
-rw-r--r--tests/storage/test_appservice.py37
-rw-r--r--tests/storage/test_base.py14
-rw-r--r--tests/storage/test_client_ips.py49
-rw-r--r--tests/storage/test_registration.py1
-rw-r--r--tests/storage/test_state.py2
-rw-r--r--tests/test_federation.py14
-rw-r--r--tests/test_state.py28
-rw-r--r--tests/test_types.py19
-rw-r--r--tests/util/test_logcontext.py24
-rw-r--r--tests/util/test_snapshot_cache.py63
-rw-r--r--tests/utils.py45
19 files changed, 195 insertions, 221 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index fdfa2cbbc4..854eb6c024 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
-    test_replace_master_key.skip = (
-        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
-    )
-
     @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
@@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             ],
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
-
-    test_upload_signatures.skip = (
-        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
-    )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 92b8726093..596ddc6970 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         mock_federation_client = Mock(spec=["put_json"])
         mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
 
+        datastores = Mock()
+        datastores.main = Mock(
+            spec=[
+                # Bits that Federation needs
+                "prep_send_transaction",
+                "delivered_txn",
+                "get_received_txn_response",
+                "set_received_txn_response",
+                "get_destination_retry_timings",
+                "get_devices_by_remote",
+                # Bits that user_directory needs
+                "get_user_directory_stream_pos",
+                "get_current_state_deltas",
+                "get_device_updates_by_remote",
+            ]
+        )
+
         hs = self.setup_test_homeserver(
-            datastore=(
-                Mock(
-                    spec=[
-                        # Bits that Federation needs
-                        "prep_send_transaction",
-                        "delivered_txn",
-                        "get_received_txn_response",
-                        "set_received_txn_response",
-                        "get_destination_retry_timings",
-                        "get_device_updates_by_remote",
-                        # Bits that user_directory needs
-                        "get_user_directory_stream_pos",
-                        "get_current_state_deltas",
-                    ]
-                )
-            ),
-            notifier=Mock(),
-            http_client=mock_federation_client,
-            keyring=mock_keyring,
+            notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
         )
 
+        hs.datastores = datastores
+
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 358b593cd4..80187406bc 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -165,6 +165,7 @@ class EmailPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0]["last_stream_ordering"]
 
@@ -175,6 +176,7 @@ class EmailPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
 
@@ -192,5 +194,6 @@ class EmailPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index af2327fb66..fe3441f081 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -104,6 +104,7 @@ class HTTPPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         last_stream_ordering = pushers[0]["last_stream_ordering"]
 
@@ -114,6 +115,7 @@ class HTTPPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
 
@@ -132,6 +134,7 @@ class HTTPPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
         last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -151,5 +154,6 @@ class HTTPPusherTests(HomeserverTestCase):
         pushers = self.get_success(
             self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
         )
+        pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 3dae83c543..2a1e7c7166 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
 
+        db_config = hs.config.database.get_single_database()
         self.master_store = self.hs.get_datastore()
         self.storage = hs.get_storage()
+        database = hs.get_datastores().databases[0]
         self.slaved_store = self.STORE_TYPE(
-            Database(hs), self.hs.get_db_conn(), self.hs
+            database, make_conn(db_config, database.engine), self.hs
         )
         self.event_id = 0
 
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 12c5e95cb5..8df58b4a63 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -237,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
 
         config = self.default_config()
         config["require_auth_for_profile_requests"] = True
+        config["limit_profile_requests_to_users_who_share_rooms"] = True
         self.hs = self.setup_test_homeserver(config=config)
 
         return self.hs
@@ -309,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         config = self.default_config()
         config["require_auth_for_profile_requests"] = True
+        config["limit_profile_requests_to_users_who_share_rooms"] = True
         self.hs = self.setup_test_homeserver(config=config)
 
         return self.hs
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index c0d0d2b44e..d0c997e385 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # Email config.
         self.email_attempts = []
 
-        def sendmail(*args, **kwargs):
+        async def sendmail(*args, **kwargs):
             self.email_attempts.append((args, kwargs))
-            return
 
         config["email"] = {
             "enable_notifs": True,
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     Set up a synchronous test server, driven by the reactor used by
     the homeserver.
     """
-    d = _sth(cleanup_func, *args, **kwargs).result
+    server = _sth(cleanup_func, *args, **kwargs)
 
-    if isinstance(d, Failure):
-        d.raiseException()
+    database = server.config.database.get_single_database()
 
     # Make the thread pool synchronous.
-    clock = d.get_clock()
-    pool = d.get_db_pool()
-
-    def runWithConnection(func, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runWithConnection,
-            func,
-            *args,
-            **kwargs
-        )
-
-    def runInteraction(interaction, *args, **kwargs):
-        return threads.deferToThreadPool(
-            pool._reactor,
-            pool.threadpool,
-            pool._runInteraction,
-            interaction,
-            *args,
-            **kwargs
-        )
+    clock = server.get_clock()
+
+    for database in server.get_datastores().databases:
+        pool = database._db_pool
+
+        def runWithConnection(func, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runWithConnection,
+                func,
+                *args,
+                **kwargs
+            )
+
+        def runInteraction(interaction, *args, **kwargs):
+            return threads.deferToThreadPool(
+                pool._reactor,
+                pool.threadpool,
+                pool._runInteraction,
+                interaction,
+                *args,
+                **kwargs
+            )
 
-    if pool:
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
-    return d
+
+    return server
 
 
 def get_clock():
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        database = Database(hs)
-        self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        self.store = ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        self.db_pool = hs.get_db_pool()
-        self.engine = hs.database_engine
-
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
             {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        database = Database(hs)
-        self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+        # We assume there is only one database in these tests
+        database = hs.get_datastores().databases[0]
+        self.db_pool = database._db_pool
+        self.engine = database.engine
+
+        db_config = hs.config.get_single_database()
+        self.store = TestTransactionStore(
+            database, make_conn(db_config, self.engine), hs
+        )
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.event_cache_size = 1
         hs.config.password_providers = []
 
-        ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+        database = hs.get_datastores().databases[0]
+        ApplicationServiceStore(
+            database, make_conn(database._database_config, database.engine), hs
+        )
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs.config.password_providers = []
 
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+            database = hs.get_datastores().databases[0]
+            ApplicationServiceStore(
+                database, make_conn(database._database_config, database.engine), hs
+            )
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         config = Mock()
         config._disable_native_upserts = True
         config.event_cache_size = 1
-        config.database_config = {"name": "sqlite3"}
-        engine = create_engine(config.database_config)
+        hs = TestHomeServer("test", config=config)
+
+        sqlite_config = {"name": "sqlite3"}
+        engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
-        hs = TestHomeServer(
-            "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
-        )
 
-        self.datastore = SQLBaseStore(Database(hs), None, hs)
+        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db._db_pool = self.db_pool
+
+        self.datastore = SQLBaseStore(db, None, hs)
 
     @defer.inlineCallbacks
     def test_insert_1col(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index fc279340d4..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(12345678)
 
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10)
 
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
-
         # Force persisting to disk
         self.reactor.advance(200)
 
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.store.db.simple_update(
                 table="devices",
-                keyvalues={"user_id": user_id, "device_id": "device_id"},
+                keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                 desc="test_devices_last_seen_bg_update",
             )
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get nulls when querying
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": None,
                 "user_agent": None,
                 "last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should now get the correct result again
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                 self.store.db.updates.do_next_background_update(100), by=0.1
             )
 
-        # Insert a user IP
         user_id = "@user:id"
+        device_id = "MY_DEVICE"
+
+        # Insert a user IP
+        self.get_success(self.store.store_device(user_id, device_id, "display name",))
         self.get_success(
             self.store.insert_client_ip(
-                user_id, "access_token", "ip", "user_agent", "device_id"
+                user_id, "access_token", "ip", "user_agent", device_id
             )
         )
 
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                     "access_token": "access_token",
                     "ip": "ip",
                     "user_agent": "user_agent",
-                    "device_id": "device_id",
+                    "device_id": device_id,
                     "last_seen": 0,
                 }
             ],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But we should still get the correct values for the device
         result = self.get_success(
-            self.store.get_last_client_ip_by_device(user_id, "device_id")
+            self.store.get_last_client_ip_by_device(user_id, device_id)
         )
 
-        r = result[(user_id, "device_id")]
+        r = result[(user_id, device_id)]
         self.assertDictContainsSubset(
             {
                 "user_id": user_id,
-                "device_id": "device_id",
+                "device_id": device_id,
                 "ip": "ip",
                 "user_agent": "user_agent",
                 "last_seen": 0,
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
-        self.db_pool = hs.get_db_pool()
 
         self.store = hs.get_datastore()
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 43200654f1..d6ecf102f8 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -35,7 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
-        self.state_datastore = self.store
+        self.state_datastore = self.storage.state.stores.state
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ad165d7295..68684460c6 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,6 +1,6 @@
 from mock import Mock
 
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
 
 from synapse.events import FrozenEvent
 from synapse.logging.context import LoggingContext
@@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Send the join, it should return None (which is not an error)
-        d = self.handler.on_receive_pdu(
-            "test.serv", join_event, sent_to_us_directly=True
+        d = ensureDeferred(
+            self.handler.on_receive_pdu(
+                "test.serv", join_event, sent_to_us_directly=True
+            )
         )
         self.reactor.advance(1)
         self.assertEqual(self.successResultOf(d), None)
@@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         with LoggingContext(request="lying_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", lying_event, sent_to_us_directly=True
+            d = ensureDeferred(
+                self.handler.on_receive_pdu(
+                    "test.serv", lying_event, sent_to_us_directly=True
+                )
             )
 
             # Step the reactor, so the database fetches come back
diff --git a/tests/test_state.py b/tests/test_state.py
index 176535947a..e0aae06be4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertEqual(2, len(prev_state_ids))
 
         self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertSetEqual(
             {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
         )
@@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_e = context_store["E"]
 
-        prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_e.get_prev_state_ids()
         self.assertSetEqual(
             {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
         )
@@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase):
         ctx_b = context_store["B"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
         )
@@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
         )
@@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
@@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(
             set([e.event_id for e in old_state]), set(current_state_ids.values())
@@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
         self.assertEqual(
             set([e.event_id for e in old_state]), set(prev_state_ids.values())
@@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
@@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
 
diff --git a/tests/test_types.py b/tests/test_types.py
index 9ab5f829b0..8d97c751ea 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError
 from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
 
 from tests import unittest
-from tests.utils import TestHomeServer
 
-mock_homeserver = TestHomeServer(hostname="my.domain")
 
-
-class UserIDTestCase(unittest.TestCase):
+class UserIDTestCase(unittest.HomeserverTestCase):
     def test_parse(self):
-        user = UserID.from_string("@1234abcd:my.domain")
+        user = UserID.from_string("@1234abcd:test")
 
         self.assertEquals("1234abcd", user.localpart)
-        self.assertEquals("my.domain", user.domain)
-        self.assertEquals(True, mock_homeserver.is_mine(user))
+        self.assertEquals("test", user.domain)
+        self.assertEquals(True, self.hs.is_mine(user))
 
     def test_pase_empty(self):
         with self.assertRaises(SynapseError):
@@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase):
         self.assertTrue(userA != userB)
 
 
-class RoomAliasTestCase(unittest.TestCase):
+class RoomAliasTestCase(unittest.HomeserverTestCase):
     def test_parse(self):
-        room = RoomAlias.from_string("#channel:my.domain")
+        room = RoomAlias.from_string("#channel:test")
 
         self.assertEquals("channel", room.localpart)
-        self.assertEquals("my.domain", room.domain)
-        self.assertEquals(True, mock_homeserver.is_mine(room))
+        self.assertEquals("test", room.domain)
+        self.assertEquals(True, self.hs.is_mine(room))
 
     def test_build(self):
         room = RoomAlias("channel", "my.domain")
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.request, "foo-bar")
 
+    @defer.inlineCallbacks
+    def test_make_deferred_yieldable_with_await(self):
+        # an async function which retuns an incomplete coroutine, but doesn't
+        # follow the synapse rules.
+
+        async def blocking_function():
+            d = defer.Deferred()
+            reactor.callLater(0, d.callback, None)
+            await d
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext() as context_one:
+            context_one.request = "one"
+
+            d1 = make_deferred_yieldable(blocking_function())
+            # make sure that the context was reset by make_deferred_yieldable
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+            yield d1
+
+            # now it should be restored
+            self._check_test_key("one")
+
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
-    def setUp(self):
-        self.cache = SnapshotCache()
-        self.cache.DURATION_MS = 1
-
-    def test_get_set(self):
-        # Check that getting a missing key returns None
-        self.assertEquals(self.cache.get(0, "key"), None)
-
-        # Check that setting a key with a deferred returns
-        # a deferred that resolves when the initial deferred does
-        d = Deferred()
-        set_result = self.cache.set(0, "key", d)
-        self.assertIsNotNone(set_result)
-        self.assertFalse(set_result.called)
-
-        # Check that getting the key before the deferred has resolved
-        # returns a deferred that resolves when the initial deferred does.
-        get_result_at_10 = self.cache.get(10, "key")
-        self.assertIsNotNone(get_result_at_10)
-        self.assertFalse(get_result_at_10.called)
-
-        # Check that the returned deferreds resolve when the initial deferred
-        # does.
-        d.callback("v")
-        self.assertTrue(set_result.called)
-        self.assertTrue(get_result_at_10.called)
-
-        # Check that getting the key after the deferred has resolved
-        # before the cache expires returns a resolved deferred.
-        get_result_at_11 = self.cache.get(11, "key")
-        self.assertIsNotNone(get_result_at_11)
-        if isinstance(get_result_at_11, Deferred):
-            # The cache may return the actual result rather than a deferred
-            self.assertTrue(get_result_at_11.called)
-
-        # Check that getting the key after the deferred has resolved
-        # after the cache expires returns None
-        get_result_at_12 = self.cache.get(12, "key")
-        self.assertIsNone(get_result_at_12)
diff --git a/tests/utils.py b/tests/utils.py
index c57da59191..e2e9cafd79 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
 
-@defer.inlineCallbacks
 def setup_test_homeserver(
     cleanup_func,
     name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex
 
-        config.database_config = {
+        database_config = {
             "name": "psycopg2",
             "args": {
                 "database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
             },
         }
     else:
-        config.database_config = {
+        database_config = {
             "name": "sqlite3",
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
-    db_engine = create_engine(config.database_config)
+    database = DatabaseConnectionConfig("master", database_config)
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
@@ -251,39 +254,30 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()
 
-    # we need to configure the connection pool to run the on_new_connection
-    # function, so that we can test code that uses custom sqlite functions
-    # (like rank).
-    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
     if datastore is None:
         hs = homeserverToUse(
             name,
             config=config,
-            db_config=config.database_config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,
             **kargs
         )
 
-        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
-        # date db
-        if not isinstance(db_engine, PostgresEngine):
-            db_conn = hs.get_db_conn()
-            yield prepare_database(db_conn, db_engine, config)
-            db_conn.commit()
-            db_conn.close()
+        hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
+
+        if isinstance(db_engine, PostgresEngine):
+            database = hs.get_datastores().databases[0]
 
-        else:
             # We need to do cleanup on PostgreSQL
             def cleanup():
                 import psycopg2
 
                 # Close all the db pools
-                hs.get_db_pool().close()
+                database._db_pool.close()
 
                 dropped = False
 
@@ -322,23 +316,12 @@ def setup_test_homeserver(
                 # Register the cleanup hook
                 cleanup_func(cleanup)
 
-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
     else:
-        # If we have been given an explicit datastore we probably want to mock
-        # out the DataStores somehow too. This all feels a bit wrong, but then
-        # mocking the stores feels wrong too.
-        datastores = Mock(datastore=datastore)
-
         hs = homeserverToUse(
             name,
-            db_pool=None,
             datastore=datastore,
-            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,