summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--.travis.yml4
-rw-r--r--changelog.d/3740.misc1
-rw-r--r--tests/handlers/test_device.py143
-rw-r--r--tests/replication/slave/storage/_base.py95
-rw-r--r--tests/replication/slave/storage/test_account_data.py20
-rw-r--r--tests/replication/slave/storage/test_events.py110
-rw-r--r--tests/replication/slave/storage/test_receipts.py15
-rw-r--r--tests/server.py1
-rw-r--r--tests/storage/test_appservice.py131
-rw-r--r--tests/storage/test_directory.py3
-rw-r--r--tests/storage/test_event_federation.py6
-rw-r--r--tests/storage/test_monthly_active_users.py138
-rw-r--r--tests/storage/test_presence.py7
-rw-r--r--tests/storage/test_profile.py2
-rw-r--r--tests/storage/test_user_directory.py2
-rw-r--r--tests/test_visibility.py2
-rw-r--r--tests/unittest.py7
-rw-r--r--tests/utils.py10
18 files changed, 356 insertions, 341 deletions
diff --git a/.travis.yml b/.travis.yml
index 318701c9f8..11c76db2e5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -35,10 +35,6 @@ matrix:
   - python: 3.6
     env: TOX_ENV=check-newsfragment
 
-  allow_failures:
-  - python: 2.7
-    env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
-
 install:
   - pip install tox
 
diff --git a/changelog.d/3740.misc b/changelog.d/3740.misc
new file mode 100644
index 0000000000..4dcb7fb5de
--- /dev/null
+++ b/changelog.d/3740.misc
@@ -0,0 +1 @@
+The test suite now passes on PostgreSQL.
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 56e7acd37c..a3aa0a1cf2 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import synapse.api.errors
 import synapse.handlers.device
 import synapse.storage
 
-from tests import unittest, utils
+from tests import unittest
 
 user1 = "@boris:aaa"
 user2 = "@theresa:bbb"
 
 
-class DeviceTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super(DeviceTestCase, self).__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
-        self.handler = None  # type: synapse.handlers.device.DeviceHandler
-        self.clock = None  # type: utils.MockClock
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield utils.setup_test_homeserver(self.addCleanup)
+class DeviceTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver("server", http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        # These tests assume that it starts 1000 seconds in.
+        self.reactor.advance(1000)
 
-    @defer.inlineCallbacks
     def test_device_is_created_if_doesnt_exist(self):
-        res = yield self.handler.check_device_registered(
-            user_id="@boris:foo",
-            device_id="fco",
-            initial_device_display_name="display name",
+        res = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@boris:foo",
+                device_id="fco",
+                initial_device_display_name="display name",
+            )
         )
         self.assertEqual(res, "fco")
 
-        dev = yield self.handler.store.get_device("@boris:foo", "fco")
+        dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    @defer.inlineCallbacks
     def test_device_is_preserved_if_exists(self):
-        res1 = yield self.handler.check_device_registered(
-            user_id="@boris:foo",
-            device_id="fco",
-            initial_device_display_name="display name",
+        res1 = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@boris:foo",
+                device_id="fco",
+                initial_device_display_name="display name",
+            )
         )
         self.assertEqual(res1, "fco")
 
-        res2 = yield self.handler.check_device_registered(
-            user_id="@boris:foo",
-            device_id="fco",
-            initial_device_display_name="new display name",
+        res2 = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@boris:foo",
+                device_id="fco",
+                initial_device_display_name="new display name",
+            )
         )
         self.assertEqual(res2, "fco")
 
-        dev = yield self.handler.store.get_device("@boris:foo", "fco")
+        dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    @defer.inlineCallbacks
     def test_device_id_is_made_up_if_unspecified(self):
-        device_id = yield self.handler.check_device_registered(
-            user_id="@theresa:foo",
-            device_id=None,
-            initial_device_display_name="display",
+        device_id = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@theresa:foo",
+                device_id=None,
+                initial_device_display_name="display",
+            )
         )
 
-        dev = yield self.handler.store.get_device("@theresa:foo", device_id)
+        dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
         self.assertEqual(dev["display_name"], "display")
 
-    @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield self._record_users()
+        self._record_users()
+
+        res = self.get_success(self.handler.get_devices_by_user(user1))
 
-        res = yield self.handler.get_devices_by_user(user1)
         self.assertEqual(3, len(res))
         device_map = {d["device_id"]: d for d in res}
         self.assertDictContainsSubset(
@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
             device_map["abc"],
         )
 
-    @defer.inlineCallbacks
     def test_get_device(self):
-        yield self._record_users()
+        self._record_users()
 
-        res = yield self.handler.get_device(user1, "abc")
+        res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertDictContainsSubset(
             {
                 "user_id": user1,
@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
             res,
         )
 
-    @defer.inlineCallbacks
     def test_delete_device(self):
-        yield self._record_users()
+        self._record_users()
 
         # delete the device
-        yield self.handler.delete_device(user1, "abc")
+        self.get_success(self.handler.delete_device(user1, "abc"))
 
         # check the device was deleted
-        with self.assertRaises(synapse.api.errors.NotFoundError):
-            yield self.handler.get_device(user1, "abc")
+        res = self.handler.get_device(user1, "abc")
+        self.pump()
+        self.assertIsInstance(
+            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        )
 
         # we'd like to check the access token was invalidated, but that's a
         # bit of a PITA.
 
-    @defer.inlineCallbacks
     def test_update_device(self):
-        yield self._record_users()
+        self._record_users()
 
         update = {"display_name": "new display"}
-        yield self.handler.update_device(user1, "abc", update)
+        self.get_success(self.handler.update_device(user1, "abc", update))
 
-        res = yield self.handler.get_device(user1, "abc")
+        res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertEqual(res["display_name"], "new display")
 
-    @defer.inlineCallbacks
     def test_update_unknown_device(self):
         update = {"display_name": "new_display"}
-        with self.assertRaises(synapse.api.errors.NotFoundError):
-            yield self.handler.update_device("user_id", "unknown_device_id", update)
+        res = self.handler.update_device("user_id", "unknown_device_id", update)
+        self.pump()
+        self.assertIsInstance(
+            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        )
 
-    @defer.inlineCallbacks
     def _record_users(self):
         # check this works for both devices which have a recorded client_ip,
         # and those which don't.
-        yield self._record_user(user1, "xyz", "display 0")
-        yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
-        yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
-        yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
+        self._record_user(user1, "xyz", "display 0")
+        self._record_user(user1, "fco", "display 1", "token1", "ip1")
+        self._record_user(user1, "abc", "display 2", "token2", "ip2")
+        self._record_user(user1, "abc", "display 2", "token3", "ip3")
+
+        self._record_user(user2, "def", "dispkay", "token4", "ip4")
 
-        yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
+        self.reactor.advance(10000)
 
-    @defer.inlineCallbacks
     def _record_user(
         self, user_id, device_id, display_name, access_token=None, ip=None
     ):
-        device_id = yield self.handler.check_device_registered(
-            user_id=user_id,
-            device_id=device_id,
-            initial_device_display_name=display_name,
+        device_id = self.get_success(
+            self.handler.check_device_registered(
+                user_id=user_id,
+                device_id=device_id,
+                initial_device_display_name=display_name,
+            )
         )
 
         if ip is not None:
-            yield self.store.insert_client_ip(
-                user_id, access_token, ip, "user_agent", device_id
+            self.get_success(
+                self.store.insert_client_ip(
+                    user_id, access_token, ip, "user_agent", device_id
+                )
             )
-            self.clock.advance_time(1000)
+            self.reactor.advance(1000)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 65df116efc..089cecfbee 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import tempfile
 
 from mock import Mock, NonCallableMock
 
-from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+import attr
 
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 
-class TestReplicationClientHandler(ReplicationClientHandler):
-    """Overrides on_rdata so that we can wait for it to happen"""
+class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
 
-    def __init__(self, store):
-        super(TestReplicationClientHandler, self).__init__(store)
-        self._rdata_awaiters = []
-
-    def await_replication(self):
-        d = Deferred()
-        self._rdata_awaiters.append(d)
-        return make_deferred_yieldable(d)
-
-    def on_rdata(self, stream_name, token, rows):
-        awaiters = self._rdata_awaiters
-        self._rdata_awaiters = []
-        super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
-        with PreserveLoggingContext():
-            for a in awaiters:
-                a.callback(None)
-
-
-class BaseSlavedStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(
-            self.addCleanup,
+        hs = self.setup_test_homeserver(
             "blue",
-            http_client=None,
             federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
-        self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+        hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
 
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
-        # XXX: mktemp is unsafe and should never be used. but we're just a test.
-        path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
-        listener = reactor.listenUNIX(path, server_factory)
-        self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
-        self.replication_handler = TestReplicationClientHandler(self.slaved_store)
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
-        client_connector = reactor.connectUNIX(path, client_factory)
-        self.addCleanup(client_factory.stopTrying)
-        self.addCleanup(client_connector.disconnect)
+
+        server = server_factory.buildProtocol(None)
+        client = client_factory.buildProtocol(None)
+
+        @attr.s
+        class FakeTransport(object):
+
+            other = attr.ib()
+            disconnecting = False
+            buffer = attr.ib(default=b'')
+
+            def registerProducer(self, producer, streaming):
+
+                self.producer = producer
+
+                def _produce():
+                    self.producer.resumeProducing()
+                    reactor.callLater(0.1, _produce)
+
+                reactor.callLater(0.0, _produce)
+
+            def write(self, byt):
+                self.buffer = self.buffer + byt
+
+                if getattr(self.other, "transport") is not None:
+                    self.other.dataReceived(self.buffer)
+                    self.buffer = b""
+
+            def writeSequence(self, seq):
+                for x in seq:
+                    self.write(x)
+
+        client.makeConnection(FakeTransport(server))
+        server.makeConnection(FakeTransport(client))
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
-        # xxx: should we be more specific in what we wait for?
-        d = self.replication_handler.await_replication()
         self.streamer.on_notifier_poke()
-        return d
+        self.pump(0.1)
 
-    @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
-        master_result = yield getattr(self.master_store, method)(*args)
-        slaved_result = yield getattr(self.slaved_store, method)(*args)
+        master_result = self.get_success(getattr(self.master_store, method)(*args))
+        slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
             self.assertEqual(master_result, expected_result)
             self.assertEqual(slaved_result, expected_result)
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
index 87cc2b2fba..43e3248703 100644
--- a/tests/replication/slave/storage/test_account_data.py
+++ b/tests/replication/slave/storage/test_account_data.py
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 
 from ._base import BaseSlavedStoreTestCase
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = SlavedAccountDataStore
 
-    @defer.inlineCallbacks
     def test_user_account_data(self):
-        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
-        yield self.replicate()
-        yield self.check(
+        self.get_success(
+            self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
+        )
+        self.replicate()
+        self.check(
             "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
         )
 
-        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
-        yield self.replicate()
-        yield self.check(
+        self.get_success(
+            self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
+        )
+        self.replicate()
+        self.check(
             "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
         )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 2ba80ccdcf..db44d33c68 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.events import FrozenEvent, _EventInternalMetadata
 from synapse.events.snapshot import EventContext
 from synapse.replication.slave.storage.events import SlavedEventStore
@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
     def tearDown(self):
         [unpatch() for unpatch in self.unpatches]
 
-    @defer.inlineCallbacks
     def test_get_latest_event_ids_in_room(self):
-        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.replicate()
-        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
+        create = self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.replicate()
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
 
-        join = yield self.persist(
+        join = self.persist(
             type="m.room.member",
             key=USER_ID,
             membership="join",
             prev_events=[(create.event_id, {})],
         )
-        yield self.replicate()
-        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
+        self.replicate()
+        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
 
-    @defer.inlineCallbacks
     def test_redactions(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
 
-        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
-        yield self.replicate()
-        yield self.check("get_event", [msg.event_id], msg)
+        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+        self.replicate()
+        self.check("get_event", [msg.event_id], msg)
 
-        redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
-        yield self.replicate()
+        redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
+        self.replicate()
 
         msg_dict = msg.get_dict()
         msg_dict["content"] = {}
         msg_dict["unsigned"]["redacted_by"] = redaction.event_id
         msg_dict["unsigned"]["redacted_because"] = redaction
         redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
-        yield self.check("get_event", [msg.event_id], redacted)
+        self.check("get_event", [msg.event_id], redacted)
 
-    @defer.inlineCallbacks
     def test_backfilled_redactions(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.member", key=USER_ID, membership="join")
 
-        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
-        yield self.replicate()
-        yield self.check("get_event", [msg.event_id], msg)
+        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
+        self.replicate()
+        self.check("get_event", [msg.event_id], msg)
 
-        redaction = yield self.persist(
+        redaction = self.persist(
             type="m.room.redaction", redacts=msg.event_id, backfill=True
         )
-        yield self.replicate()
+        self.replicate()
 
         msg_dict = msg.get_dict()
         msg_dict["content"] = {}
         msg_dict["unsigned"]["redacted_by"] = redaction.event_id
         msg_dict["unsigned"]["redacted_because"] = redaction
         redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
-        yield self.check("get_event", [msg.event_id], redacted)
+        self.check("get_event", [msg.event_id], redacted)
 
-    @defer.inlineCallbacks
     def test_invites(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
-        event = yield self.persist(
-            type="m.room.member", key=USER_ID_2, membership="invite"
-        )
-        yield self.replicate()
-        yield self.check(
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.check("get_invited_rooms_for_user", [USER_ID_2], [])
+        event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
+
+        self.replicate()
+
+        self.check(
             "get_invited_rooms_for_user",
             [USER_ID_2],
             [
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             ],
         )
 
-    @defer.inlineCallbacks
     def test_push_actions_for_user(self):
-        yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.persist(type="m.room.join", key=USER_ID, membership="join")
-        yield self.persist(
+        self.persist(type="m.room.create", key="", creator=USER_ID)
+        self.persist(type="m.room.join", key=USER_ID, membership="join")
+        self.persist(
             type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
         )
-        event1 = yield self.persist(
-            type="m.room.message", msgtype="m.text", body="hello"
-        )
-        yield self.replicate()
-        yield self.check(
+        event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
+        self.replicate()
+        self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
             {"highlight_count": 0, "notify_count": 0},
         )
 
-        yield self.persist(
+        self.persist(
             type="m.room.message",
             msgtype="m.text",
             body="world",
             push_actions=[(USER_ID_2, ["notify"])],
         )
-        yield self.replicate()
-        yield self.check(
+        self.replicate()
+        self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
             {"highlight_count": 0, "notify_count": 1},
         )
 
-        yield self.persist(
+        self.persist(
             type="m.room.message",
             msgtype="m.text",
             body="world",
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
             ],
         )
-        yield self.replicate()
-        yield self.check(
+        self.replicate()
+        self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
             {"highlight_count": 1, "notify_count": 2},
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     event_id = 0
 
-    @defer.inlineCallbacks
     def persist(
         self,
         sender=USER_ID,
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             depth = self.event_id
 
         if not prev_events:
-            latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
-                room_id
+            latest_event_ids = self.get_success(
+                self.master_store.get_latest_event_ids_in_room(room_id)
             )
             prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
 
@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             )
         else:
             state_handler = self.hs.get_state_handler()
-            context = yield state_handler.compute_event_context(event)
+            context = self.get_success(state_handler.compute_event_context(event))
 
-        yield self.master_store.add_push_actions_to_staging(
+        self.master_store.add_push_actions_to_staging(
             event.event_id, {user_id: actions for user_id, actions in push_actions}
         )
 
         ordering = None
         if backfill:
-            yield self.master_store.persist_events([(event, context)], backfilled=True)
+            self.get_success(
+                self.master_store.persist_events([(event, context)], backfilled=True)
+            )
         else:
-            ordering, _ = yield self.master_store.persist_event(event, context)
+            ordering, _ = self.get_success(
+                self.master_store.persist_event(event, context)
+            )
 
         if ordering:
             event.internal_metadata.stream_ordering = ordering
 
-        defer.returnValue(event)
+        return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index ae1adeded1..f47d94f690 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 
 from ._base import BaseSlavedStoreTestCase
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = SlavedReceiptsStore
 
-    @defer.inlineCallbacks
     def test_receipt(self):
-        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
-        yield self.master_store.insert_receipt(
-            ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
+        self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+        self.get_success(
+            self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
         )
+        self.replicate()
+        self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
diff --git a/tests/server.py b/tests/server.py
index 7dbdb7f8ea..615bba1b59 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -232,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
 
     clock.threadpool = ThreadPool()
     pool.threadpool = ThreadPool()
+    pool.running = True
     return d
 
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index c893990454..3f0083831b 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
-        config = Mock(
-            app_service_config_files=self.as_yaml_files,
-            event_cache_size=1,
-            password_providers=[],
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
+        hs.config.app_service_config_files = self.as_yaml_files
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         self.as_token = "token1"
         self.as_url = "some_url"
         self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(None, hs)
+        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
 
     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def setUp(self):
         self.as_yaml_files = []
 
-        config = Mock(
-            app_service_config_files=self.as_yaml_files,
-            event_cache_size=1,
-            password_providers=[],
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
+
+        hs.config.app_service_config_files = self.as_yaml_files
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         self.db_pool = hs.get_db_pool()
+        self.engine = hs.database_engine
 
         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         self.as_yaml_files = []
 
-        self.store = TestTransactionStore(None, hs)
+        self.store = TestTransactionStore(hs.get_db_conn(), hs)
 
     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             self.as_yaml_files.append(as_token)
 
     def _set_state(self, id, state, txn=None):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_state(as_id, state, last_txn) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_state(as_id, state, last_txn) "
+                "VALUES(?,?,?)"
+            ),
             (id, state, txn),
         )
 
     def _insert_txn(self, as_id, txn_id, events):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+                "VALUES(?,?,?)"
+            ),
             (as_id, txn_id, json.dumps([e.event_id for e in events])),
         )
 
     def _set_last_txn(self, as_id, txn_id):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_state(as_id, last_txn, state) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_state(as_id, last_txn, state) "
+                "VALUES(?,?,?)"
+            ),
             (as_id, txn_id, ApplicationServiceState.UP),
         )
 
     @defer.inlineCallbacks
     def test_get_appservice_state_none(self):
-        service = Mock(id=999)
+        service = Mock(id="999")
         state = yield self.store.get_appservice_state(service)
         self.assertEquals(None, state)
 
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[1]["id"])
         yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         rows = yield self.db_pool.runQuery(
-            "SELECT as_id FROM application_services_state WHERE state=?",
+            self.engine.convert_param_style(
+                "SELECT as_id FROM application_services_state WHERE state=?"
+            ),
             (ApplicationServiceState.DOWN,),
         )
         self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
         rows = yield self.db_pool.runQuery(
-            "SELECT as_id FROM application_services_state WHERE state=?",
+            self.engine.convert_param_style(
+                "SELECT as_id FROM application_services_state WHERE state=?"
+            ),
             (ApplicationServiceState.UP,),
         )
         self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
 
         res = yield self.db_pool.runQuery(
-            "SELECT last_txn FROM application_services_state WHERE as_id=?",
+            self.engine.convert_param_style(
+                "SELECT last_txn FROM application_services_state WHERE as_id=?"
+            ),
             (service.id,),
         )
         self.assertEquals(1, len(res))
         self.assertEquals(txn_id, res[0][0])
 
         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+            self.engine.convert_param_style(
+                "SELECT * FROM application_services_txns WHERE txn_id=?"
+            ),
+            (txn_id,),
         )
         self.assertEquals(0, len(res))
 
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
 
         res = yield self.db_pool.runQuery(
-            "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
+            self.engine.convert_param_style(
+                "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+            ),
             (service.id,),
         )
         self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.assertEquals(ApplicationServiceState.UP, res[0][1])
 
         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+            self.engine.convert_param_style(
+                "SELECT * FROM application_services_txns WHERE txn_id=?"
+            ),
+            (txn_id,),
         )
         self.assertEquals(0, len(res))
 
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(suffix="1")
         f2 = self._write_config(suffix="2")
 
-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
-        ApplicationServiceStore(None, hs)
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
+        ApplicationServiceStore(hs.get_db_conn(), hs)
 
     @defer.inlineCallbacks
     def test_duplicate_ids(self):
         f1 = self._write_config(id="id", suffix="1")
         f2 = self._write_config(id="id", suffix="2")
 
-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(None, hs)
+            ApplicationServiceStore(hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(as_token="as_token", suffix="1")
         f2 = self._write_config(as_token="as_token", suffix="2")
 
-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )
 
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(None, hs)
+            ApplicationServiceStore(hs.get_db_conn(), hs)
 
         e = cm.exception
         self.assertIn(f1, str(e))
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index b4510c1c8d..4e128e1047 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -16,7 +16,6 @@
 
 from twisted.internet import defer
 
-from synapse.storage.directory import DirectoryStore
 from synapse.types import RoomAlias, RoomID
 
 from tests import unittest
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = DirectoryStore(None, hs)
+        self.store = hs.get_datastore()
 
         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2fdf34fdf6..0d4e74d637 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                 (
                     "INSERT INTO events ("
                     "   room_id, event_id, type, depth, topological_ordering,"
-                    "   content, processed, outlier) "
-                    "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+                    "   content, processed, outlier, stream_ordering) "
+                    "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
                 ),
-                (room_id, event_id, i, i, True, False),
+                (room_id, event_id, i, i, True, False, i),
             )
 
             txn.execute(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index f2ed866ae7..2036287288 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -13,25 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
-import tests.unittest
-import tests.utils
-from tests.utils import setup_test_homeserver
+from tests.unittest import HomeserverTestCase
 
 FORTY_DAYS = 40 * 24 * 60 * 60
 
 
-class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
+class MonthlyActiveUsersTestCase(HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+
+        hs = self.setup_test_homeserver()
+        self.store = hs.get_datastore()
+
+        # Advance the clock a bit
+        reactor.advance(FORTY_DAYS)
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = self.hs.get_datastore()
+        return hs
 
-    @defer.inlineCallbacks
     def test_initialise_reserved_users(self):
         self.hs.config.max_mau_value = 5
         user1 = "@user1:server"
@@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
         ]
         user_num = len(threepids)
 
-        yield self.store.register(user_id=user1, token="123", password_hash=None)
-
-        yield self.store.register(user_id=user2, token="456", password_hash=None)
+        self.store.register(user_id=user1, token="123", password_hash=None)
+        self.store.register(user_id=user2, token="456", password_hash=None)
+        self.pump()
 
         now = int(self.hs.get_clock().time_msec())
-        yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
-        yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
-        yield self.store.initialise_reserved_users(threepids)
+        self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        self.store.initialise_reserved_users(threepids)
+        self.pump()
 
-        active_count = yield self.store.get_monthly_active_count()
+        active_count = self.store.get_monthly_active_count()
 
         # Test total counts
-        self.assertEquals(active_count, user_num)
+        self.assertEquals(self.get_success(active_count), user_num)
 
         # Test user is marked as active
-
-        timestamp = yield self.store.user_last_seen_monthly_active(user1)
-        self.assertTrue(timestamp)
-        timestamp = yield self.store.user_last_seen_monthly_active(user2)
-        self.assertTrue(timestamp)
+        timestamp = self.store.user_last_seen_monthly_active(user1)
+        self.assertTrue(self.get_success(timestamp))
+        timestamp = self.store.user_last_seen_monthly_active(user2)
+        self.assertTrue(self.get_success(timestamp))
 
         # Test that users are never removed from the db.
         self.hs.config.max_mau_value = 0
 
-        self.hs.get_clock().advance_time(FORTY_DAYS)
+        self.reactor.advance(FORTY_DAYS)
 
-        yield self.store.reap_monthly_active_users()
+        self.store.reap_monthly_active_users()
+        self.pump()
 
-        active_count = yield self.store.get_monthly_active_count()
-        self.assertEquals(active_count, user_num)
+        active_count = self.store.get_monthly_active_count()
+        self.assertEquals(self.get_success(active_count), user_num)
 
         # Test that regalar users are removed from the db
         ru_count = 2
-        yield self.store.upsert_monthly_active_user("@ru1:server")
-        yield self.store.upsert_monthly_active_user("@ru2:server")
-        active_count = yield self.store.get_monthly_active_count()
+        self.store.upsert_monthly_active_user("@ru1:server")
+        self.store.upsert_monthly_active_user("@ru2:server")
+        self.pump()
 
-        self.assertEqual(active_count, user_num + ru_count)
+        active_count = self.store.get_monthly_active_count()
+        self.assertEqual(self.get_success(active_count), user_num + ru_count)
         self.hs.config.max_mau_value = user_num
-        yield self.store.reap_monthly_active_users()
+        self.store.reap_monthly_active_users()
+        self.pump()
 
-        active_count = yield self.store.get_monthly_active_count()
-        self.assertEquals(active_count, user_num)
+        active_count = self.store.get_monthly_active_count()
+        self.assertEquals(self.get_success(active_count), user_num)
 
-    @defer.inlineCallbacks
     def test_can_insert_and_count_mau(self):
-        count = yield self.store.get_monthly_active_count()
-        self.assertEqual(0, count)
+        count = self.store.get_monthly_active_count()
+        self.assertEqual(0, self.get_success(count))
 
-        yield self.store.upsert_monthly_active_user("@user:server")
-        count = yield self.store.get_monthly_active_count()
+        self.store.upsert_monthly_active_user("@user:server")
+        self.pump()
 
-        self.assertEqual(1, count)
+        count = self.store.get_monthly_active_count()
+        self.assertEqual(1, self.get_success(count))
 
-    @defer.inlineCallbacks
     def test_user_last_seen_monthly_active(self):
         user_id1 = "@user1:server"
         user_id2 = "@user2:server"
         user_id3 = "@user3:server"
 
-        result = yield self.store.user_last_seen_monthly_active(user_id1)
-        self.assertFalse(result == 0)
-        yield self.store.upsert_monthly_active_user(user_id1)
-        yield self.store.upsert_monthly_active_user(user_id2)
-        result = yield self.store.user_last_seen_monthly_active(user_id1)
-        self.assertTrue(result > 0)
-        result = yield self.store.user_last_seen_monthly_active(user_id3)
-        self.assertFalse(result == 0)
+        result = self.store.user_last_seen_monthly_active(user_id1)
+        self.assertFalse(self.get_success(result) == 0)
+
+        self.store.upsert_monthly_active_user(user_id1)
+        self.store.upsert_monthly_active_user(user_id2)
+        self.pump()
+
+        result = self.store.user_last_seen_monthly_active(user_id1)
+        self.assertGreater(self.get_success(result), 0)
+
+        result = self.store.user_last_seen_monthly_active(user_id3)
+        self.assertNotEqual(self.get_success(result), 0)
 
-    @defer.inlineCallbacks
     def test_reap_monthly_active_users(self):
         self.hs.config.max_mau_value = 5
         initial_users = 10
         for i in range(initial_users):
-            yield self.store.upsert_monthly_active_user("@user%d:server" % i)
-        count = yield self.store.get_monthly_active_count()
-        self.assertTrue(count, initial_users)
-        yield self.store.reap_monthly_active_users()
-        count = yield self.store.get_monthly_active_count()
-        self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
-
-        self.hs.get_clock().advance_time(FORTY_DAYS)
-        yield self.store.reap_monthly_active_users()
-        count = yield self.store.get_monthly_active_count()
-        self.assertEquals(count, 0)
+            self.store.upsert_monthly_active_user("@user%d:server" % i)
+        self.pump()
+
+        count = self.store.get_monthly_active_count()
+        self.assertTrue(self.get_success(count), initial_users)
+
+        self.store.reap_monthly_active_users()
+        self.pump()
+        count = self.store.get_monthly_active_count()
+        self.assertEquals(
+            self.get_success(count), initial_users - self.hs.config.max_mau_value
+        )
+
+        self.reactor.advance(FORTY_DAYS)
+        self.store.reap_monthly_active_users()
+        self.pump()
+
+        count = self.store.get_monthly_active_count()
+        self.assertEquals(self.get_success(count), 0)
diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py
index b5b58ff660..c7a63f39b9 100644
--- a/tests/storage/test_presence.py
+++ b/tests/storage/test_presence.py
@@ -16,19 +16,18 @@
 
 from twisted.internet import defer
 
-from synapse.storage.presence import PresenceStore
 from synapse.types import UserID
 
 from tests import unittest
-from tests.utils import MockClock, setup_test_homeserver
+from tests.utils import setup_test_homeserver
 
 
 class PresenceStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock())
+        hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = PresenceStore(None, hs)
+        self.store = hs.get_datastore()
 
         self.u_apple = UserID.from_string("@apple:test")
         self.u_banana = UserID.from_string("@banana:test")
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index a1f6618bf9..45824bd3b2 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)
 
-        self.store = ProfileStore(None, hs)
+        self.store = ProfileStore(hs.get_db_conn(), hs)
 
         self.u_frank = UserID.from_string("@frank:test")
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index b46e0ea7e2..0dde1ab2fe 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = UserDirectoryStore(None, self.hs)
+        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 8d8ce0cab9..2eea3b098b 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -96,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         events_to_filter.append(evt)
 
         # the erasey user gets erased
-        self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+        yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
 
         # ... and the filtering happens.
         filtered = yield filter_events_for_server(
diff --git a/tests/unittest.py b/tests/unittest.py
index 8b513bb32b..a3d39920db 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 
 import twisted
 import twisted.logger
+from twisted.internet.defer import Deferred
 from twisted.trial import unittest
 
 from synapse.http.server import JsonResource
@@ -281,12 +282,14 @@ class HomeserverTestCase(TestCase):
         kwargs.update(self._hs_args)
         return setup_test_homeserver(self.addCleanup, *args, **kwargs)
 
-    def pump(self):
+    def pump(self, by=0.0):
         """
         Pump the reactor enough that Deferreds will fire.
         """
-        self.reactor.pump([0.0] * 100)
+        self.reactor.pump([by] * 100)
 
     def get_success(self, d):
+        if not isinstance(d, Deferred):
+            return d
         self.pump()
         return self.successResultOf(d)
diff --git a/tests/utils.py b/tests/utils.py
index 63e30dc6c0..114470d641 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,8 +30,8 @@ from synapse.config.server import ServerConfig
 from synapse.federation.transport import server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
-from synapse.storage import DataStore, PostgresEngine
-from synapse.storage.engines import create_engine
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import (
     _get_or_create_schema_state,
     _setup_new_database,
@@ -42,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
 USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
+LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
 POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
 POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
 
@@ -244,8 +245,9 @@ def setup_test_homeserver(
                 cur.close()
                 db_conn.close()
 
-            # Register the cleanup hook
-            cleanup_func(cleanup)
+            if not LEAVE_DB:
+                # Register the cleanup hook
+                cleanup_func(cleanup)
 
         hs.setup()
     else: