summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py4
-rw-r--r--tests/handlers/test_typing.py7
-rw-r--r--tests/http/__init__.py17
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py11
-rw-r--r--tests/http/test_proxyagent.py334
-rw-r--r--tests/push/test_http.py2
-rw-r--r--tests/replication/slave/storage/_base.py1
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/rest/admin/test_admin.py78
-rw-r--r--tests/server.py30
-rw-r--r--tests/storage/test__base.py2
-rw-r--r--tests/storage/test_devices.py16
-rw-r--r--tests/storage/test_redaction.py11
-rw-r--r--tests/storage/test_room.py3
-rw-r--r--tests/storage/test_roommember.py3
-rw-r--r--tests/storage/test_state.py153
-rw-r--r--tests/test_federation.py15
-rw-r--r--tests/test_state.py3
-rw-r--r--tests/test_visibility.py18
-rw-r--r--tests/util/caches/test_descriptors.py4
-rw-r--r--tests/utils.py10
21 files changed, 640 insertions, 92 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c4f0bbd3dd..8efd39c7f7 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.datastore.store_server_verify_keys(
+        r = self.hs.get_datastore().store_server_verify_keys(
             "server9",
             time.time() * 1000,
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
@@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         )
 
         key1 = signedjson.key.generate_signing_key(1)
-        r = self.hs.datastore.store_server_verify_keys(
+        r = self.hs.get_datastore().store_server_verify_keys(
             "server9",
             time.time() * 1000,
             [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 67f1013051..5ec568f4e6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                         "get_received_txn_response",
                         "set_received_txn_response",
                         "get_destination_retry_timings",
-                        "get_devices_by_remote",
+                        "get_device_updates_by_remote",
                         # Bits that user_directory needs
                         "get_user_directory_stream_pos",
                         "get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             retry_timings_res
         )
 
-        self.datastore.get_devices_by_remote.return_value = (0, [])
+        self.datastore.get_device_updates_by_remote.return_value = (0, [])
 
         def get_received_txn_response(*args):
             return defer.succeed(None)
@@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_to_device_stream_token = lambda: 0
         self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
         self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
+        self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+            None
+        )
 
     def test_started_typing_local(self):
         self.room_members = [U_APPLE, U_BANANA]
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
 from OpenSSL import SSL
 from OpenSSL.SSL import Connection
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS  # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS  # noqa: F401
+
+
+def get_test_https_policy():
+    """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+    Returns:
+        IPolicyForHTTPS
+    """
+    ca_file = get_test_ca_cert_file()
+    with open(ca_file) as stream:
+        content = stream.read()
+    cert = Certificate.loadPEM(content)
+    trust_root = trustRootFromCertificates([cert])
+    return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
 
 
 def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d7025264..cfcd98ff7d 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
             FakeTransport(client_protocol, self.reactor, server_tls_protocol)
         )
 
+        # grab a hold of the TLS connection, in case it gets torn down
+        server_tls_connection = server_tls_protocol._tlsConnection
+
+        # fish the test server back out of the server-side TLS protocol.
+        http_protocol = server_tls_protocol.wrappedProtocol
+
         # give the reactor a pump to get the TLS juices flowing.
         self.reactor.pump((0.1,))
 
         # check the SNI
-        server_name = server_tls_protocol._tlsConnection.get_servername()
+        server_name = server_tls_connection.get_servername()
         self.assertEqual(
             server_name,
             expected_sni,
             "Expected SNI %s but got %s" % (expected_sni, server_name),
         )
 
-        # fish the test server back out of the server-side TLS protocol.
-        return server_tls_protocol.wrappedProtocol
+        return http_protocol
 
     @defer.inlineCallbacks
     def _make_get_request(self, uri):
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces  # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+    def _make_connection(
+        self, client_factory, server_factory, ssl=False, expected_sni=None
+    ):
+        """Builds a test server, and completes the outgoing client connection
+
+        Args:
+            client_factory (interfaces.IProtocolFactory): the the factory that the
+                application is trying to use to make the outbound connection. We will
+                invoke it to build the client Protocol
+
+            server_factory (interfaces.IProtocolFactory): a factory to build the
+                server-side protocol
+
+            ssl (bool): If true, we will expect an ssl connection and wrap
+                server_factory with a TLSMemoryBIOFactory
+
+            expected_sni (bytes|None): the expected SNI value
+
+        Returns:
+            IProtocol: the server Protocol returned by server_factory
+        """
+        if ssl:
+            server_factory = _wrap_server_factory_for_tls(server_factory)
+
+        server_protocol = server_factory.buildProtocol(None)
+
+        # now, tell the client protocol factory to build the client protocol,
+        # and wire the output of said protocol up to the server via
+        # a FakeTransport.
+        #
+        # Normally this would be done by the TCP socket code in Twisted, but we are
+        # stubbing that out here.
+        client_protocol = client_factory.buildProtocol(None)
+        client_protocol.makeConnection(
+            FakeTransport(server_protocol, self.reactor, client_protocol)
+        )
+
+        # tell the server protocol to send its stuff back to the client, too
+        server_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, server_protocol)
+        )
+
+        if ssl:
+            http_protocol = server_protocol.wrappedProtocol
+            tls_connection = server_protocol._tlsConnection
+        else:
+            http_protocol = server_protocol
+            tls_connection = None
+
+        # give the reactor a pump to get the TLS juices flowing (if needed)
+        self.reactor.advance(0)
+
+        if expected_sni is not None:
+            server_name = tls_connection.get_servername()
+            self.assertEqual(
+                server_name,
+                expected_sni,
+                "Expected SNI %s but got %s" % (expected_sni, server_name),
+            )
+
+        return http_protocol
+
+    def test_http_request(self):
+        agent = ProxyAgent(self.reactor)
+
+        self.reactor.lookups["test.com"] = "1.2.3.4"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 80)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request(self):
+        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+        self.reactor.lookups["test.com"] = "1.2.3.4"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory,
+            _get_test_protocol_factory(),
+            ssl=True,
+            expected_sni=b"test.com",
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_http_request_via_proxy(self):
+        agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 8888)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"http://test.com")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request_via_proxy(self):
+        agent = ProxyAgent(
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            https_proxy=b"proxy.com",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 1080)
+
+        # make a test HTTP server, and wire up the client
+        proxy_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # fish the transports back out so that we can do the old switcheroo
+        s2c_transport = proxy_server.transport
+        client_protocol = s2c_transport.other
+        c2s_transport = client_protocol.transport
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending CONNECT request
+        self.assertEqual(len(proxy_server.requests), 1)
+
+        request = proxy_server.requests[0]
+        self.assertEqual(request.method, b"CONNECT")
+        self.assertEqual(request.path, b"test.com:443")
+
+        # tell the proxy server not to close the connection
+        proxy_server.persistent = True
+
+        # this just stops the http Request trying to do a chunked response
+        # request.setHeader(b"Content-Length", b"0")
+        request.finish()
+
+        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+        ssl_protocol = ssl_factory.buildProtocol(None)
+        http_server = ssl_protocol.wrappedProtocol
+
+        ssl_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, ssl_protocol)
+        )
+        c2s_transport.other = ssl_protocol
+
+        self.reactor.advance(0)
+
+        server_name = ssl_protocol._tlsConnection.get_servername()
+        expected_sni = b"test.com"
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+    """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+    The resultant factory will create a TLS server which presents a certificate
+    signed by our test CA, valid for the domains in `sanlist`
+
+    Args:
+        factory (interfaces.IProtocolFactory): protocol factory to wrap
+        sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+    Returns:
+        interfaces.IProtocolFactory
+    """
+    if sanlist is None:
+        sanlist = [b"DNS:test.com"]
+
+    connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+    return TLSMemoryBIOFactory(
+        connection_creator, isClient=False, wrappedFactory=factory
+    )
+
+
+def _get_test_protocol_factory():
+    """Get a protocol Factory which will build an HTTPChannel
+
+    Returns:
+        interfaces.IProtocolFactory
+    """
+    server_factory = Factory.forProtocol(HTTPChannel)
+
+    # Request.finish expects the factory to have a 'log' method.
+    server_factory.log = _log_request
+
+    return server_factory
+
+
+def _log_request(request):
+    """Implements Factory.log, which is expected by Request.finish"""
+    logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..af2327fb66 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
         config = self.default_config()
         config["start_pushers"] = True
 
-        hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+        hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
 
         return hs
 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 104349cdbd..4f924ce451 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -41,6 +41,7 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
 
         self.master_store = self.hs.get_datastore()
+        self.storage = hs.get_storage()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index a368117b43..b68e9fe082 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -234,7 +234,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
         )
         msg, msgctx = self.build_event()
-        self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)]))
+        self.get_success(
+            self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+        )
         self.replicate()
 
         event_source = RoomEventSource(self.hs)
@@ -290,10 +292,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         if backfill:
             self.get_success(
-                self.master_store.persist_events([(event, context)], backfilled=True)
+                self.storage.persistence.persist_events(
+                    [(event, context)], backfilled=True
+                )
             )
         else:
-            self.get_success(self.master_store.persist_event(event, context))
+            self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index d3a4f717f7..8e1ca8b738 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -561,3 +561,81 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
 
         return channel.json_body["groups"]
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+    """Test /purge_room admin API.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+    def test_purge_room(self):
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # All users have to have left the room.
+        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/purge_room"
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            {"room_id": room_id},
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that the following tables have been purged of all rows related to the room.
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+        ):
+            count = self.get_success(
+                self.store._simple_select_one_onecol(
+                    table="events",
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
diff --git a/tests/server.py b/tests/server.py
index e397ebe8fa..f878aeaada 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -161,7 +161,11 @@ def make_request(
         path = path.encode("ascii")
 
     # Decorate it to be the full path, if we're using shorthand
-    if shorthand and not path.startswith(b"/_matrix"):
+    if (
+        shorthand
+        and not path.startswith(b"/_matrix")
+        and not path.startswith(b"/_synapse")
+    ):
         path = b"/_matrix/client/r0/" + path
         path = path.replace(b"//", b"/")
 
@@ -391,11 +395,24 @@ class FakeTransport(object):
             self.disconnecting = True
             if self._protocol:
                 self._protocol.connectionLost(reason)
-            self.disconnected = True
+
+            # if we still have data to write, delay until that is done
+            if self.buffer:
+                logger.info(
+                    "FakeTransport: Delaying disconnect until buffer is flushed"
+                )
+            else:
+                self.disconnected = True
 
     def abortConnection(self):
         logger.info("FakeTransport: abortConnection()")
-        self.loseConnection()
+
+        if not self.disconnecting:
+            self.disconnecting = True
+            if self._protocol:
+                self._protocol.connectionLost(None)
+
+        self.disconnected = True
 
     def pauseProducing(self):
         if not self.producer:
@@ -426,6 +443,9 @@ class FakeTransport(object):
             self._reactor.callLater(0.0, _produce)
 
     def write(self, byt):
+        if self.disconnecting:
+            raise Exception("Writing to disconnecting FakeTransport")
+
         self.buffer = self.buffer + byt
 
         # always actually do the write asynchronously. Some protocols (notably the
@@ -470,6 +490,10 @@ class FakeTransport(object):
         if self.buffer and self.autoflush:
             self._reactor.callLater(0.0, self.flush)
 
+        if not self.buffer and self.disconnecting:
+            logger.info("FakeTransport: Buffer now empty, completing disconnect")
+            self.disconnected = True
+
 
 def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
     """
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index dd49a14524..9b81b536f5 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -197,7 +197,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
 
         a.func.prefill(("foo",), ObservableDeferred(d))
 
-        self.assertEquals(a.func("foo"), d.result)
+        self.assertEquals(a.func("foo").result, d.result)
         self.assertEquals(callcount[0], 0)
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..6f8d990959 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote(self):
+    def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "somehost", -1, limit=100
         )
 
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         self._check_devices_in_updates(device_ids, device_updates)
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote_limited(self):
+    def test_get_device_updates_by_remote_limited(self):
         # Test breaking the update limit in 1, 101, and 1 device_id segments
 
         # first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         #
 
         # first we should get a single update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", -1, limit=100
         )
         self._check_devices_in_updates(device_ids1, device_updates)
 
         # Then we should get an empty list back as the 101 devices broke the limit
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", now_stream_id, limit=100
         )
         self.assertEqual(len(device_updates), 0)
 
         # The 101 devices should've been cleared, so we should now just get one device
         # update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", now_stream_id, limit=100
         )
         self._check_devices_in_updates(device_ids3, device_updates)
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))
 
-        received_device_ids = {update["device_id"] for update in device_updates}
+        received_device_ids = {
+            update["device_id"] for edu_type, update in device_updates
+        }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 427d3c49c5..4561c3e383 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -39,6 +39,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -73,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -95,7 +96,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -116,7 +117,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
@@ -263,7 +264,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.get_success(self.store.persist_event(event_1, context_1))
+        self.get_success(self.storage.persistence.persist_event(event_1, context_1))
 
         event_2, context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
@@ -282,7 +283,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
-        self.get_success(self.store.persist_event(event_2, context_2))
+        self.get_success(self.storage.persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
         fetched = self.get_success(self.store.get_event(redaction_event_id1))
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1bee45706f..3ddaa151fe 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -62,6 +62,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
         # Room events need the full datastore, for persist_event() and
         # get_room_state()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_factory = hs.get_event_factory()
 
         self.room = RoomID.from_string("!abcde:test")
@@ -72,7 +73,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.store.persist_event(
+        yield self.storage.persistence.persist_event(
             self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
         )
 
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 447a3c6ffb..9ddd17f73d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -44,6 +44,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # We can't test the RoomMemberStore on its own without the other event
         # storage logic
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -70,7 +71,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        self.get_success(self.store.persist_event(event, context))
+        self.get_success(self.storage.persistence.persist_event(event, context))
 
         return event
 
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 5c2cf3c2db..43200654f1 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -34,6 +34,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_datastore = self.store
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -63,7 +65,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.store.persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
 
         return event
 
@@ -82,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups_ids(
+        state_group_map = yield self.storage.state.get_state_groups_ids(
             self.room, [e2.event_id]
         )
         self.assertEqual(len(state_group_map), 1)
@@ -101,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+        state_group_map = yield self.storage.state.get_state_groups(
+            self.room, [e2.event_id]
+        )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
 
@@ -141,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.store.get_state_for_event(e5.event_id)
+        state = yield self.storage.state.get_state_for_event(e5.event_id)
 
         self.assertIsNotNone(e4)
 
@@ -157,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
         )
 
@@ -181,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id,
             state_filter=StateFilter(
                 types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -199,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -215,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+        group_ids = yield self.storage.state.get_state_groups_ids(
+            room_id, [e5.event_id]
+        )
         group = list(group_ids.keys())[0]
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -237,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -250,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -267,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -287,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -304,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -317,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -331,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
         # deliberately remove e2 (room name) from the _state_group_cache
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-            group
-        )
+        (
+            is_all,
+            known_absent,
+            state_dict_ids,
+        ) = self.state_datastore._state_group_cache.get(group)
 
         self.assertEqual(is_all, True)
         self.assertEqual(known_absent, set())
@@ -346,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         state_dict_ids.pop((e2.type, e2.state_key))
-        self.store._state_group_cache.invalidate(group)
-        self.store._state_group_cache.update(
-            sequence=self.store._state_group_cache.sequence,
+        self.state_datastore._state_group_cache.invalidate(group)
+        self.state_datastore._state_group_cache.update(
+            sequence=self.state_datastore._state_group_cache.sequence,
             key=group,
             value=state_dict_ids,
             # list fetched keys so it knows it's partial
             fetched_keys=((e1.type, e1.state_key),),
         )
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-            group
-        )
+        (
+            is_all,
+            known_absent,
+            state_dict_ids,
+        ) = self.state_datastore._state_group_cache.get(group)
 
         self.assertEqual(is_all, False)
         self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@@ -369,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -381,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -394,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -405,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -424,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -435,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -448,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -459,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/test_federation.py b/tests/test_federation.py
index a73f18f88e..7d82b58466 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase):
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
             maybeDeferred(
-                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+                self.homeserver.get_datastore().get_latest_event_ids_in_room,
+                self.room_id,
             )
         )[0]
 
@@ -58,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         self.handler = self.homeserver.get_handlers().federation_handler
-        self.handler.do_auth = lambda *a, **b: succeed(True)
+        self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+            context
+        )
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
             pdus
@@ -75,7 +78,8 @@ class MessageAcceptTests(unittest.TestCase):
         self.assertEqual(
             self.successResultOf(
                 maybeDeferred(
-                    self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+                    self.homeserver.get_datastore().get_latest_event_ids_in_room,
+                    self.room_id,
                 )
             )[0],
             "$join:test.serv",
@@ -97,7 +101,8 @@ class MessageAcceptTests(unittest.TestCase):
         # Figure out what the most recent event is
         most_recent = self.successResultOf(
             maybeDeferred(
-                self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+                self.homeserver.get_datastore().get_latest_event_ids_in_room,
+                self.room_id,
             )
         )[0]
 
@@ -137,6 +142,6 @@ class MessageAcceptTests(unittest.TestCase):
 
         # Make sure the invalid event isn't there
         extrem = maybeDeferred(
-            self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id
+            self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id
         )
         self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb46..38246555bd 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -158,10 +158,12 @@ class Graph(object):
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = StateGroupStore()
+        storage = Mock(main=self.store, state=self.store)
         hs = Mock(
             spec_set=[
                 "config",
                 "get_datastore",
+                "get_storage",
                 "get_auth",
                 "get_state_handler",
                 "get_clock",
@@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+        hs.get_storage.return_value = storage
 
         self.state = StateHandler(hs)
         self.event_id = 0
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 18f1a0035d..f7381b2885 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 import logging
 
+from mock import Mock
+
 from twisted.internet import defer
 from twisted.internet.defer import succeed
 
@@ -36,6 +38,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self.store = self.hs.get_datastore()
+        self.storage = self.hs.get_storage()
 
         yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
 
@@ -62,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             events_to_filter.append(evt)
 
         filtered = yield filter_events_for_server(
-            self.store, "test_server", events_to_filter
+            self.storage, "test_server", events_to_filter
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -100,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         # ... and the filtering happens.
         filtered = yield filter_events_for_server(
-            self.store, "test_server", events_to_filter
+            self.storage, "test_server", events_to_filter
         )
 
         for i in range(0, len(events_to_filter)):
@@ -137,7 +140,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         event, context = yield self.event_creation_handler.create_new_client_event(
             builder
         )
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -159,7 +162,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -180,7 +183,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             builder
         )
 
-        yield self.hs.get_datastore().persist_event(event, context)
+        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -257,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         logger.info("Starting filtering")
         start = time.time()
+
+        storage = Mock()
+        storage.main = test_store
+        storage.state = test_store
+
         filtered = yield filter_events_for_server(
             test_store, "test_server", events_to_filter
         )
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index f907903511..39e360fe24 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -310,14 +310,14 @@ class DescriptorTestCase(unittest.TestCase):
 
         obj.mock.return_value = ["spam", "eggs"]
         r = obj.fn(1, 2)
-        self.assertEqual(r, ["spam", "eggs"])
+        self.assertEqual(r.result, ["spam", "eggs"])
         obj.mock.assert_called_once_with(1, 2)
         obj.mock.reset_mock()
 
         # a call with different params should call the mock again
         obj.mock.return_value = ["chips"]
         r = obj.fn(1, 3)
-        self.assertEqual(r, ["chips"])
+        self.assertEqual(r.result, ["chips"])
         obj.mock.assert_called_once_with(1, 3)
         obj.mock.reset_mock()
 
diff --git a/tests/utils.py b/tests/utils.py
index 8cced4b7e8..7dc9bdc505 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -325,10 +325,16 @@ def setup_test_homeserver(
         if homeserverToUse.__name__ == "TestHomeServer":
             hs.setup_master()
     else:
+        # If we have been given an explicit datastore we probably want to mock
+        # out the DataStores somehow too. This all feels a bit wrong, but then
+        # mocking the stores feels wrong too.
+        datastores = Mock(datastore=datastore)
+
         hs = homeserverToUse(
             name,
             db_pool=None,
             datastore=datastore,
+            datastores=datastores,
             config=config,
             version_string="Synapse/tests",
             database_engine=db_engine,
@@ -646,7 +652,7 @@ def create_room(hs, room_id, creator_id):
         creator_id (str)
     """
 
-    store = hs.get_datastore()
+    persistence_store = hs.get_storage().persistence
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
@@ -663,4 +669,4 @@ def create_room(hs, room_id, creator_id):
 
     event, context = yield event_creation_handler.create_new_client_event(builder)
 
-    yield store.persist_event(event, context)
+    yield persistence_store.persist_event(event, context)