summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py12
-rw-r--r--tests/api/test_filtering.py5
-rw-r--r--tests/crypto/test_keyring.py6
-rw-r--r--tests/handlers/test_admin.py2
-rw-r--r--tests/handlers/test_auth.py11
-rw-r--r--tests/handlers/test_device.py82
-rw-r--r--tests/handlers/test_directory.py10
-rw-r--r--tests/handlers/test_e2e_keys.py87
-rw-r--r--tests/handlers/test_e2e_room_keys.py2
-rw-r--r--tests/handlers/test_federation.py2
-rw-r--r--tests/handlers/test_oidc.py10
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py7
-rw-r--r--tests/handlers/test_register.py24
-rw-r--r--tests/module_api/test_api.py153
-rw-r--r--tests/replication/_base.py224
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/replication/test_sharded_event_persister.py102
-rw-r--r--tests/rest/client/test_shadow_banned.py2
-rw-r--r--tests/rest/client/test_third_party_rules.py170
-rw-r--r--tests/rest/client/third_party_rules.py79
-rw-r--r--tests/rest/client/v1/test_directory.py11
-rw-r--r--tests/rest/client/v1/test_events.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py2
-rw-r--r--tests/server.py42
-rw-r--r--tests/storage/test_appservice.py14
-rw-r--r--tests/storage/test_id_generators.py25
-rw-r--r--tests/test_federation.py2
-rw-r--r--tests/test_phone_home.py2
-rw-r--r--tests/unittest.py6
-rw-r--r--tests/utils.py124
32 files changed, 970 insertions, 260 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 8ab56ec94c..cb6f29d670 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,7 +19,6 @@ import pymacaroons
 
 from twisted.internet import defer
 
-import synapse.handlers.auth
 from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
@@ -36,20 +35,15 @@ from tests import unittest
 from tests.utils import mock_getRawHeaders, setup_test_homeserver
 
 
-class TestHandlers:
-    def __init__(self, hs):
-        self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
-
-
 class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.state_handler = Mock()
         self.store = Mock()
 
-        self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
+        self.hs = yield setup_test_homeserver(self.addCleanup)
         self.hs.get_datastore = Mock(return_value=self.store)
-        self.hs.handlers = TestHandlers(self.hs)
+        self.hs.get_auth_handler().store = self.store
         self.auth = Auth(self.hs)
 
         # AuthBlocking reads from the hs' config on initialization. We need to
@@ -283,7 +277,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_device = Mock(return_value=defer.succeed(None))
 
         token = yield defer.ensureDeferred(
-            self.hs.handlers.auth_handler.get_access_token_for_user_id(
+            self.hs.get_auth_handler().get_access_token_for_user_id(
                 USER_ID, "DEVICE", valid_until_ms=None
             )
         )
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index d2d535d23c..c98ae75974 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -50,10 +50,7 @@ class FilteringTestCase(unittest.TestCase):
         self.mock_http_client.put_json = DeferredMockCallable()
 
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            handlers=None,
-            http_client=self.mock_http_client,
-            keyring=Mock(),
+            self.addCleanup, http_client=self.mock_http_client, keyring=Mock(),
         )
 
         self.filtering = hs.get_filtering()
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 8ff1460c0d..697916a019 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock()
-        hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
+        hs = self.setup_test_homeserver(http_client=self.http_client)
         return hs
 
     def test_get_keys_from_server(self):
@@ -395,9 +395,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
             }
         ]
 
-        return self.setup_test_homeserver(
-            handlers=None, http_client=self.http_client, config=config
-        )
+        return self.setup_test_homeserver(http_client=self.http_client, config=config)
 
     def build_perspectives_response(
         self, server_name: str, signing_key: SigningKey, valid_until_ts: int,
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index fc37c4328c..5c2b4de1a6 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -35,7 +35,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.admin_handler = hs.get_handlers().admin_handler
+        self.admin_handler = hs.get_admin_handler()
 
         self.user1 = self.register_user("user1", "password")
         self.token1 = self.login("user1", "password")
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 97877c2e42..b5055e018c 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -21,24 +21,17 @@ from twisted.internet import defer
 import synapse
 import synapse.api.errors
 from synapse.api.errors import ResourceLimitError
-from synapse.handlers.auth import AuthHandler
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
-class AuthHandlers:
-    def __init__(self, hs):
-        self.auth_handler = AuthHandler(hs)
-
-
 class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
-        self.hs.handlers = AuthHandlers(self.hs)
-        self.auth_handler = self.hs.handlers.auth_handler
+        self.hs = yield setup_test_homeserver(self.addCleanup)
+        self.auth_handler = self.hs.get_auth_handler()
         self.macaroon_generator = self.hs.get_macaroon_generator()
 
         # MAU tests
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 969d44c787..4512c51311 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase):
                 )
             )
             self.reactor.advance(1000)
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver("server", http_client=None)
+        self.handler = hs.get_device_handler()
+        self.registration = hs.get_registration_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        return hs
+
+    def test_dehydrate_and_rehydrate_device(self):
+        user_id = "@boris:dehydration"
+
+        self.get_success(self.store.register_user(user_id, "foobar"))
+
+        # First check if we can store and fetch a dehydrated device
+        stored_dehydrated_device_id = self.get_success(
+            self.handler.store_dehydrated_device(
+                user_id=user_id,
+                device_data={"device_data": {"foo": "bar"}},
+                initial_device_display_name="dehydrated device",
+            )
+        )
+
+        retrieved_device_id, device_data = self.get_success(
+            self.handler.get_dehydrated_device(user_id=user_id)
+        )
+
+        self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+        self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+        # Create a new login for the user and dehydrated the device
+        device_id, access_token = self.get_success(
+            self.registration.register_device(
+                user_id=user_id, device_id=None, initial_display_name="new device",
+            )
+        )
+
+        # Trying to claim a nonexistent device should throw an error
+        self.get_failure(
+            self.handler.rehydrate_device(
+                user_id=user_id,
+                access_token=access_token,
+                device_id="not the right device ID",
+            ),
+            synapse.api.errors.NotFoundError,
+        )
+
+        # dehydrating the right devices should succeed and change our device ID
+        # to the dehydrated device's ID
+        res = self.get_success(
+            self.handler.rehydrate_device(
+                user_id=user_id,
+                access_token=access_token,
+                device_id=retrieved_device_id,
+            )
+        )
+
+        self.assertEqual(res, {"success": True})
+
+        # make sure that our device ID has changed
+        user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
+
+        self.assertEqual(user_info["device_id"], retrieved_device_id)
+
+        # make sure the device has the display name that was set from the login
+        res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
+
+        self.assertEqual(res["display_name"], "new device")
+
+        # make sure that the device ID that we were initially assigned no longer exists
+        self.get_failure(
+            self.handler.get_device(user_id, device_id),
+            synapse.api.errors.NotFoundError,
+        )
+
+        # make sure that there's no device available for dehydrating now
+        ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+
+        self.assertIsNone(ret)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index bc0c5aefdc..2ce6dc9528 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -48,7 +48,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             federation_registry=self.mock_registry,
         )
 
-        self.handler = hs.get_handlers().directory_handler
+        self.handler = hs.get_directory_handler()
 
         self.store = hs.get_datastore()
 
@@ -110,7 +110,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
     ]
 
     def prepare(self, reactor, clock, hs):
-        self.handler = hs.get_handlers().directory_handler
+        self.handler = hs.get_directory_handler()
 
         # Create user
         self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -173,7 +173,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.handler = hs.get_handlers().directory_handler
+        self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
 
         # Create user
@@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.handler = hs.get_handlers().directory_handler
+        self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
 
         # Create user
@@ -442,7 +442,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.result)
 
         self.room_list_handler = hs.get_room_list_handler()
-        self.directory_handler = hs.get_handlers().directory_handler
+        self.directory_handler = hs.get_directory_handler()
 
         return hs
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 366dcfb670..924f29f051 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -33,13 +33,15 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         super().__init__(*args, **kwargs)
         self.hs = None  # type: synapse.server.HomeServer
         self.handler = None  # type: synapse.handlers.e2e_keys.E2eKeysHandler
+        self.store = None  # type: synapse.storage.Storage
 
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, handlers=None, federation_client=mock.Mock()
+            self.addCleanup, federation_client=mock.Mock()
         )
         self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+        self.store = self.hs.get_datastore()
 
     @defer.inlineCallbacks
     def test_query_local_devices_no_devices(self):
@@ -172,6 +174,89 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_fallback_key(self):
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        fallback_key = {"alg1:k1": "key1"}
+        otk = {"alg1:k2": "key2"}
+
+        # we shouldn't have any unused fallback keys yet
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key},
+            )
+        )
+
+        # we should now have an unused alg1 key
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, ["alg1"])
+
+        # claiming an OTK when no OTKs are available should return the fallback
+        # key
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # we shouldn't have any unused fallback keys again
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+        )
+        self.assertEqual(res, [])
+
+        # claiming an OTK again should return the same fallback key
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # if the user uploads a one-time key, the next claim should fetch the
+        # one-time key, and then go back to the fallback
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": otk}
+            )
+        )
+
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+        )
+
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+    @defer.inlineCallbacks
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 7adde9b9de..45f201a399 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -54,7 +54,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, handlers=None, replication_layer=mock.Mock()
+            self.addCleanup, replication_layer=mock.Mock()
         )
         self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
         self.local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 96fea58673..9ef80fe502 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -38,7 +38,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(http_client=None)
-        self.handler = hs.get_handlers().federation_handler
+        self.handler = hs.get_federation_handler()
         self.store = hs.get_datastore()
         return hs
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d5087e58be..b6f436c016 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -286,9 +286,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 h._validate_metadata,
             )
 
-        # Tests for configs that the userinfo endpoint
+        # Tests for configs that require the userinfo endpoint
         self.assertFalse(h._uses_userinfo)
-        h._scopes = []  # do not request the openid scope
+        self.assertEqual(h._user_profile_method, "auto")
+        h._user_profile_method = "userinfo_endpoint"
+        self.assertTrue(h._uses_userinfo)
+
+        # Revert the profile method and do not request the "openid" scope.
+        h._user_profile_method = "auto"
+        h._scopes = []
         self.assertTrue(h._uses_userinfo)
         self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
 
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 306dcfe944..914c82e7a8 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -470,7 +470,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.federation_sender = hs.get_federation_sender()
         self.event_builder_factory = hs.get_event_builder_factory()
-        self.federation_handler = hs.get_handlers().federation_handler
+        self.federation_handler = hs.get_federation_handler()
         self.presence_handler = hs.get_presence_handler()
 
         # self.event_builder_for_2 = EventBuilderFactory(hs)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8e95e53d9e..a69fa28b41 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
 
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
-from synapse.handlers.profile import MasterProfileHandler
 from synapse.types import UserID
 
 from tests import unittest
@@ -28,11 +27,6 @@ from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
-class ProfileHandlers:
-    def __init__(self, hs):
-        self.profile_handler = MasterProfileHandler(hs)
-
-
 class ProfileTestCase(unittest.TestCase):
     """ Tests profile management. """
 
@@ -51,7 +45,6 @@ class ProfileTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             self.addCleanup,
             http_client=None,
-            handlers=None,
             resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_server=Mock(),
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..bdf3d0a8a2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,6 @@ from mock import Mock
 from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
-from synapse.handlers.register import RegistrationHandler
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, UserID, create_requester
 
@@ -29,11 +28,6 @@ from tests.utils import mock_getRawHeaders
 from .. import unittest
 
 
-class RegistrationHandlers:
-    def __init__(self, hs):
-        self.registration_handler = RegistrationHandler(hs)
-
-
 class RegistrationTestCase(unittest.HomeserverTestCase):
     """ Tests the RegistrationHandler. """
 
@@ -154,7 +148,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         user_id = self.get_success(self.handler.register_user(localpart="jeff"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -193,7 +187,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         self.get_failure(directory_handler.get_association(room_alias), SynapseError)
 
@@ -205,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -237,7 +231,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = self.get_success(self.handler.register_user(localpart="jeff"))
 
         # Ensure the room was created.
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -266,7 +260,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = self.get_success(self.handler.register_user(localpart="jeff"))
 
         # Ensure the room was created.
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -304,7 +298,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = self.get_success(self.handler.register_user(localpart="jeff"))
 
         # Ensure the room was created.
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -347,7 +341,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         )
 
         # Ensure the room was created.
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -384,7 +378,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         user_id = self.get_success(self.handler.register_user(localpart="jeff"))
 
         # Ensure the room was created.
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.hs.get_directory_handler()
         room_alias = RoomAlias.from_string(room_alias_str)
         room_id = self.get_success(directory_handler.get_association(room_alias))
 
@@ -413,7 +407,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            event_creation_handler.send_nonmember_event(requester, event, context)
+            event_creation_handler.handle_new_client_event(requester, event, context)
         )
 
         # Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 04de0b9dbe..9b573ac24d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -12,16 +12,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from mock import Mock
 
-from synapse.module_api import ModuleApi
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
 
 from tests.unittest import HomeserverTestCase
 
 
 class ModuleApiTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, homeserver):
         self.store = homeserver.get_datastore()
-        self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+        self.module_api = homeserver.get_module_api()
+        self.event_creation_handler = homeserver.get_event_creation_handler()
 
     def test_can_register_user(self):
         """Tests that an external module can register a user"""
@@ -52,3 +63,141 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that the displayname was assigned
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
+
+    def test_sending_events_into_room(self):
+        """Tests that a module can send events into a room"""
+        # Mock out create_and_send_nonmember_event to check whether events are being sent
+        self.event_creation_handler.create_and_send_nonmember_event = Mock(
+            spec=[],
+            side_effect=self.event_creation_handler.create_and_send_nonmember_event,
+        )
+
+        # Create a user and room to play with
+        user_id = self.register_user("summer", "monkey")
+        tok = self.login("summer", "monkey")
+        room_id = self.helper.create_room_as(user_id, tok=tok)
+
+        # Create and send a non-state event
+        content = {"body": "I am a puppet", "msgtype": "m.text"}
+        event_dict = {
+            "room_id": room_id,
+            "type": "m.room.message",
+            "content": content,
+            "sender": user_id,
+        }
+        event = self.get_success(
+            self.module_api.create_and_send_event_into_room(event_dict)
+        )  # type: EventBase
+        self.assertEqual(event.sender, user_id)
+        self.assertEqual(event.type, "m.room.message")
+        self.assertEqual(event.room_id, room_id)
+        self.assertFalse(hasattr(event, "state_key"))
+        self.assertDictEqual(event.content, content)
+
+        # Check that the event was sent
+        self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+            create_requester(user_id),
+            event_dict,
+            ratelimit=False,
+            ignore_shadow_ban=True,
+        )
+
+        # Create and send a state event
+        content = {
+            "events_default": 0,
+            "users": {user_id: 100},
+            "state_default": 50,
+            "users_default": 0,
+            "events": {"test.event.type": 25},
+        }
+        event_dict = {
+            "room_id": room_id,
+            "type": "m.room.power_levels",
+            "content": content,
+            "sender": user_id,
+            "state_key": "",
+        }
+        event = self.get_success(
+            self.module_api.create_and_send_event_into_room(event_dict)
+        )  # type: EventBase
+        self.assertEqual(event.sender, user_id)
+        self.assertEqual(event.type, "m.room.power_levels")
+        self.assertEqual(event.room_id, room_id)
+        self.assertEqual(event.state_key, "")
+        self.assertDictEqual(event.content, content)
+
+        # Check that the event was sent
+        self.event_creation_handler.create_and_send_nonmember_event.assert_called_with(
+            create_requester(user_id),
+            {
+                "type": "m.room.power_levels",
+                "content": content,
+                "room_id": room_id,
+                "sender": user_id,
+                "state_key": "",
+            },
+            ratelimit=False,
+            ignore_shadow_ban=True,
+        )
+
+        # Check that we can't send membership events
+        content = {
+            "membership": "leave",
+        }
+        event_dict = {
+            "room_id": room_id,
+            "type": "m.room.member",
+            "content": content,
+            "sender": user_id,
+            "state_key": user_id,
+        }
+        self.get_failure(
+            self.module_api.create_and_send_event_into_room(event_dict), Exception
+        )
+
+    def test_public_rooms(self):
+        """Tests that a room can be added and removed from the public rooms list,
+        as well as have its public rooms directory state queried.
+        """
+        # Create a user and room to play with
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+        room_id = self.helper.create_room_as(user_id, tok=tok)
+
+        # The room should not currently be in the public rooms directory
+        is_in_public_rooms = self.get_success(
+            self.module_api.public_room_list_manager.room_is_in_public_room_list(
+                room_id
+            )
+        )
+        self.assertFalse(is_in_public_rooms)
+
+        # Let's try adding it to the public rooms directory
+        self.get_success(
+            self.module_api.public_room_list_manager.add_room_to_public_room_list(
+                room_id
+            )
+        )
+
+        # And checking whether it's in there...
+        is_in_public_rooms = self.get_success(
+            self.module_api.public_room_list_manager.room_is_in_public_room_list(
+                room_id
+            )
+        )
+        self.assertTrue(is_in_public_rooms)
+
+        # Let's remove it again
+        self.get_success(
+            self.module_api.public_room_list_manager.remove_room_from_public_room_list(
+                room_id
+            )
+        )
+
+        # Should be gone
+        is_in_public_rooms = self.get_success(
+            self.module_api.public_room_list_manager.room_is_in_public_room_list(
+                room_id
+            )
+        )
+        self.assertFalse(is_in_public_rooms)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..81ea985b9f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Any, Callable, List, Optional, Tuple
 
 import attr
+import hiredis
 
 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
@@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
     GenericWorkerServer,
 )
 from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource, streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = self.hs.get_replication_streamer()
 
+        # Fake in memory Redis server that servers can connect to.
+        self._redis_server = FakeRedisPubSubServer()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
+        # A map from a HS instance to the associated HTTP Site to use for
+        # handling inbound HTTP requests to that instance.
+        self._hs_to_site = {self.hs: self.site}
+
+        if self.hs.config.redis.redis_enabled:
+            # Handle attempts to connect to fake redis server.
+            self.reactor.add_tcp_client_callback(
+                "localhost", 6379, self.connect_any_redis_attempts,
+            )
 
-        self._worker_hs_to_resource = {}
+            self.hs.get_tcp_replication().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
         # manually have to go and explicitly set it up each time (plus sometimes
         # it is impossible to write the handling explicitly in the tests).
+        #
+        # Register the master replication listener:
         self.reactor.add_tcp_client_callback(
-            "1.2.3.4", 8765, self._handle_http_replication_attempt
+            "1.2.3.4",
+            8765,
+            lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
     def create_test_json_resource(self):
@@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             **kwargs
         )
 
+        # If the instance is in the `instance_map` config then workers may try
+        # and send HTTP requests to it, so we register it with
+        # `_handle_http_replication_attempt` like we do with the master HS.
+        instance_name = worker_hs.get_instance_name()
+        instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+        if instance_loc:
+            # Ensure the host is one that has a fake DNS entry.
+            if instance_loc.host not in self.reactor.lookups:
+                raise Exception(
+                    "Host does not have an IP for instance_map[%r].host = %r"
+                    % (instance_name, instance_loc.host,)
+                )
+
+            self.reactor.add_tcp_client_callback(
+                self.reactor.lookups[instance_loc.host],
+                instance_loc.port,
+                lambda: self._handle_http_replication_attempt(
+                    worker_hs, instance_loc.port
+                ),
+            )
+
         store = worker_hs.get_datastore()
         store.db_pool._db_pool = self.database_pool._db_pool
 
-        repl_handler = ReplicationCommandHandler(worker_hs)
-        client = ClientReplicationStreamProtocol(
-            worker_hs, "client", "test", self.clock, repl_handler,
-        )
-        server = self.server_factory.buildProtocol(None)
+        # Set up TCP replication between master and the new worker if we don't
+        # have Redis support enabled.
+        if not worker_hs.config.redis_enabled:
+            repl_handler = ReplicationCommandHandler(worker_hs)
+            client = ClientReplicationStreamProtocol(
+                worker_hs, "client", "test", self.clock, repl_handler,
+            )
+            server = self.server_factory.buildProtocol(None)
 
-        client_transport = FakeTransport(server, self.reactor)
-        client.makeConnection(client_transport)
+            client_transport = FakeTransport(server, self.reactor)
+            client.makeConnection(client_transport)
 
-        server_transport = FakeTransport(client, self.reactor)
-        server.makeConnection(server_transport)
+            server_transport = FakeTransport(client, self.reactor)
+            server.makeConnection(server_transport)
 
         # Set up a resource for the worker
-        resource = ReplicationRestResource(self.hs)
+        resource = ReplicationRestResource(worker_hs)
 
         for servlet in self.servlets:
             servlet(worker_hs, resource)
 
-        self._worker_hs_to_resource[worker_hs] = resource
+        self._hs_to_site[worker_hs] = SynapseSite(
+            logger_name="synapse.access.http.fake",
+            site_tag="{}-{}".format(
+                worker_hs.config.server.server_name, worker_hs.get_instance_name()
+            ),
+            config=worker_hs.config.server.listeners[0],
+            resource=resource,
+            server_version_string="1",
+        )
+
+        if worker_hs.config.redis.redis_enabled:
+            worker_hs.get_tcp_replication().start_replication(worker_hs)
 
         return worker_hs
 
@@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         return config
 
     def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
-        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+        render(request, self._hs_to_site[worker_hs].resource, self.reactor)
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.pump()
 
-    def _handle_http_replication_attempt(self):
-        """Handles a connection attempt to the master replication HTTP
-        listener.
+    def _handle_http_replication_attempt(self, hs, repl_port):
+        """Handles a connection attempt to the given HS replication HTTP
+        listener on the given port.
         """
 
         # We should have at least one outbound connection attempt, where the
@@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.assertGreaterEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
         self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 8765)
+        self.assertEqual(port, repl_port)
 
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Set up the server side protocol
         channel = _PushHTTPChannel(self.reactor)
         channel.requestFactory = request_factory
-        channel.site = self.site
+        channel.site = self._hs_to_site[hs]
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # inside `connecTCP` before the connection has been passed back to the
         # code that requested the TCP connection.
 
+    def connect_any_redis_attempts(self):
+        """If redis is enabled we need to deal with workers connecting to a
+        redis server. We don't want to use a real Redis server so we use a
+        fake one.
+        """
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "localhost")
+        self.assertEqual(port, 6379)
+
+        client_protocol = client_factory.buildProtocol(None)
+        server_protocol = self._redis_server.buildProtocol(None)
+
+        client_to_server_transport = FakeTransport(
+            server_protocol, self.reactor, client_protocol
+        )
+        client_protocol.makeConnection(client_to_server_transport)
+
+        server_to_client_transport = FakeTransport(
+            client_protocol, self.reactor, server_protocol
+        )
+        server_protocol.makeConnection(server_to_client_transport)
+
+        return client_to_server_transport, server_to_client_transport
+
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +547,105 @@ class _PullToPushProducer:
                 pass
 
             self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+    """A fake Redis server for pub/sub.
+    """
+
+    def __init__(self):
+        self._subscribers = set()
+
+    def add_subscriber(self, conn):
+        """A connection has called SUBSCRIBE
+        """
+        self._subscribers.add(conn)
+
+    def remove_subscriber(self, conn):
+        """A connection has called UNSUBSCRIBE
+        """
+        self._subscribers.discard(conn)
+
+    def publish(self, conn, channel, msg) -> int:
+        """A connection want to publish a message to subscribers.
+        """
+        for sub in self._subscribers:
+            sub.send(["message", channel, msg])
+
+        return len(self._subscribers)
+
+    def buildProtocol(self, addr):
+        return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+    """A connection from a client talking to the fake Redis server.
+    """
+
+    def __init__(self, server: FakeRedisPubSubServer):
+        self._server = server
+        self._reader = hiredis.Reader()
+
+    def dataReceived(self, data):
+        self._reader.feed(data)
+
+        # We might get multiple messages in one packet.
+        while True:
+            msg = self._reader.gets()
+
+            if msg is False:
+                # No more messages.
+                return
+
+            if not isinstance(msg, list):
+                # Inbound commands should always be a list
+                raise Exception("Expected redis list")
+
+            self.handle_command(msg[0], *msg[1:])
+
+    def handle_command(self, command, *args):
+        """Received a Redis command from the client.
+        """
+
+        # We currently only support pub/sub.
+        if command == b"PUBLISH":
+            channel, message = args
+            num_subscribers = self._server.publish(self, channel, message)
+            self.send(num_subscribers)
+        elif command == b"SUBSCRIBE":
+            (channel,) = args
+            self._server.add_subscriber(self)
+            self.send(["subscribe", channel, 1])
+        else:
+            raise Exception("Unknown command")
+
+    def send(self, msg):
+        """Send a message back to the client.
+        """
+        raw = self.encode(msg).encode("utf-8")
+
+        self.transport.write(raw)
+        self.transport.flush()
+
+    def encode(self, obj):
+        """Encode an object to its Redis format.
+
+        Supports: strings/bytes, integers and list/tuples.
+        """
+
+        if isinstance(obj, bytes):
+            # We assume bytes are just unicode strings.
+            obj = obj.decode("utf-8")
+
+        if isinstance(obj, str):
+            return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+        if isinstance(obj, int):
+            return ":{val}\r\n".format(val=obj)
+        if isinstance(obj, (list, tuple)):
+            items = "".join(self.encode(a) for a in obj)
+            return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+        raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+    def connectionLost(self, reason):
+        self._server.remove_subscriber(self)
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 1d7edee5ba..9c4a9c3563 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -207,7 +207,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
     def create_room_with_remote_server(self, user, token, remote_server="other_server"):
         room = self.helper.create_room_as(user, tok=token)
         store = self.hs.get_datastore()
-        federation = self.hs.get_handlers().federation_handler
+        federation = self.hs.get_federation_handler()
 
         prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
         room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..6068d14905
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks event persisting sharding works
+    """
+
+    # Event persister sharding requires postgres (due to needing
+    # `MutliWriterIdGenerator`).
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["redis"] = {"enabled": "true"}
+        conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+        conf["instance_map"] = {
+            "worker1": {"host": "testserv", "port": 1001},
+            "worker2": {"host": "testserv", "port": 1002},
+        }
+        return conf
+
+    def test_basic(self):
+        """Simple test to ensure that multiple rooms can be created and joined,
+        and that different rooms get handled by different instances.
+        """
+
+        self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "worker1"},
+        )
+
+        self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "worker2"},
+        )
+
+        persisted_on_1 = False
+        persisted_on_2 = False
+
+        store = self.hs.get_datastore()
+
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Keep making new rooms until we see rooms being persisted on both
+        # workers.
+        for _ in range(10):
+            # Create a room
+            room = self.helper.create_room_as(user_id, tok=access_token)
+
+            # The other user joins
+            self.helper.join(
+                room=room, user=self.other_user_id, tok=self.other_access_token
+            )
+
+            # The other user sends some messages
+            rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+            event_id = rseponse["event_id"]
+
+            # The event position includes which instance persisted the event.
+            pos = self.get_success(store.get_position_for_event(event_id))
+
+            persisted_on_1 |= pos.instance_name == "worker1"
+            persisted_on_2 |= pos.instance_name == "worker2"
+
+            if persisted_on_1 and persisted_on_2:
+                break
+
+        self.assertTrue(persisted_on_1)
+        self.assertTrue(persisted_on_2)
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index dfe4bf7762..6bb02b9630 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -78,7 +78,7 @@ class RoomTestCase(_ShadowBannedBase):
 
     def test_invite_3pid(self):
         """Ensure that a 3PID invite does not attempt to contact the identity server."""
-        identity_handler = self.hs.get_handlers().identity_handler
+        identity_handler = self.hs.get_identity_handler()
         identity_handler.lookup_3pid = Mock(
             side_effect=AssertionError("This should not get called")
         )
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..d03e121664
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+from typing import Dict
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+    def __init__(self, config: Dict, module_api: ModuleApi):
+        # keep a record of the "current" rules module, so that the test can patch
+        # it if desired.
+        thread_local.rules_module = self
+        self.module_api = module_api
+
+    async def on_create_room(
+        self, requester: Requester, config: dict, is_requester_admin: bool
+    ):
+        return True
+
+    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+        return True
+
+    @staticmethod
+    def parse_config(config):
+        return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+    return thread_local.rules_module
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["third_party_event_rules"] = {
+            "module": __name__ + ".ThirdPartyRulesTestModule",
+            "config": {},
+        }
+        return config
+
+    def prepare(self, reactor, clock, homeserver):
+        # Create a user and room to play with during the tests
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+    def test_third_party_rules(self):
+        """Tests that a forbidden event is forbidden from being sent, but an allowed one
+        can be sent.
+        """
+        # patch the rules module with a Mock which will return False for some event
+        # types
+        async def check(ev, state):
+            return ev.type != "foo.bar.forbidden"
+
+        callback = Mock(spec=[], side_effect=check)
+        current_rules_module().check_event_allowed = callback
+
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+            {},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        callback.assert_called_once()
+
+        # there should be various state events in the state arg: do some basic checks
+        state_arg = callback.call_args[0][1]
+        for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+            self.assertIn(k, state_arg)
+            ev = state_arg[k]
+            self.assertEqual(ev.type, k[0])
+            self.assertEqual(ev.state_key, k[1])
+
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+            {},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_modify_event(self):
+        """Tests that the module can successfully tweak an event before it is persisted.
+        """
+        # first patch the event checker so that it will modify the event
+        async def check(ev: EventBase, state):
+            ev.content = {"x": "y"}
+            return True
+
+        current_rules_module().check_event_allowed = check
+
+        # now send the event
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+            {"x": "x"},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        event_id = channel.json_body["event_id"]
+
+        # ... and check that it got modified
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["x"], "y")
+
+    def test_send_event(self):
+        """Tests that the module can send an event into a room via the module api"""
+        content = {
+            "msgtype": "m.text",
+            "body": "Hello!",
+        }
+        event_dict = {
+            "room_id": self.room_id,
+            "type": "m.room.message",
+            "content": content,
+            "sender": self.user_id,
+        }
+        event = self.get_success(
+            current_rules_module().module_api.create_and_send_event_into_room(
+                event_dict
+            )
+        )  # type: EventBase
+
+        self.assertEquals(event.sender, self.user_id)
+        self.assertEquals(event.room_id, self.room_id)
+        self.assertEquals(event.type, "m.room.message")
+        self.assertEquals(event.content, content)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 8c24add530..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule:
-    def __init__(self, config):
-        pass
-
-    def check_event_allowed(self, event, context):
-        if event.type == "foo.bar.forbidden":
-            return False
-        else:
-            return True
-
-    @staticmethod
-    def parse_config(config):
-        return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        admin.register_servlets,
-        login.register_servlets,
-        room.register_servlets,
-    ]
-
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
-        config["third_party_event_rules"] = {
-            "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
-            "config": {},
-        }
-
-        self.hs = self.setup_test_homeserver(config=config)
-        return self.hs
-
-    def test_third_party_rules(self):
-        """Tests that a forbidden event is forbidden from being sent, but an allowed one
-        can be sent.
-        """
-        user_id = self.register_user("kermit", "monkey")
-        tok = self.login("kermit", "monkey")
-
-        room_id = self.helper.create_room_as(user_id, tok=tok)
-
-        request, channel = self.make_request(
-            "PUT",
-            "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
-            {},
-            access_token=tok,
-        )
-        self.render(request)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
-
-        request, channel = self.make_request(
-            "PUT",
-            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
-            {},
-            access_token=tok,
-        )
-        self.render(request)
-        self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 633b7dbda0..ea5a7f3739 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -21,6 +21,7 @@ from synapse.types import RoomAlias
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -67,10 +68,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.ensure_user_joined_room()
         self.set_alias_via_directory(400, alias_length=256)
 
-    def test_state_event_in_room(self):
+    @override_config({"default_room_version": 5})
+    def test_state_event_user_in_v5_room(self):
+        """Test that a regular user can add alias events before room v6"""
         self.ensure_user_joined_room()
         self.set_alias_via_state_event(200)
 
+    @override_config({"default_room_version": 6})
+    def test_state_event_v6_room(self):
+        """Test that a regular user can *not* add alias events from room v6"""
+        self.ensure_user_joined_room()
+        self.set_alias_via_state_event(403)
+
     def test_directory_in_room(self):
         self.ensure_user_joined_room()
         self.set_alias_via_directory(200)
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index f75520877f..3397ba5579 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -42,7 +42,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(config=config)
 
-        hs.get_handlers().federation_handler = Mock()
+        hs.get_federation_handler = Mock()
 
         return hs
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0d809d25d5..9ba5f9d943 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -32,6 +32,7 @@ from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 PATH_PREFIX = b"/_matrix/client/api/v1"
 
@@ -47,7 +48,10 @@ class RoomBase(unittest.HomeserverTestCase):
             "red", http_client=None, federation_client=Mock(),
         )
 
-        self.hs.get_federation_handler = Mock(return_value=Mock())
+        self.hs.get_federation_handler = Mock()
+        self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
+            return_value=make_awaitable(None)
+        )
 
         async def _insert_client_ip(*args, **kwargs):
             return None
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 94d2bf2eb1..cd58ee7792 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -44,7 +44,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         self.event_source = hs.get_event_sources().sources["typing"]
 
-        hs.get_handlers().federation_handler = Mock()
+        hs.get_federation_handler = Mock()
 
         async def get_user_by_access_token(token=None, allow_guest=False):
             return {
diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..4d33b84097 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,8 +1,11 @@
 import json
 import logging
+from collections import deque
 from io import SEEK_END, BytesIO
+from typing import Callable
 
 import attr
+from typing_extensions import Deque
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
@@ -251,6 +254,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self._tcp_callbacks = {}
         self._udp = []
         lookups = self.lookups = {}
+        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]()
 
         @implementer(IResolverSimple)
         class FakeResolver:
@@ -272,10 +276,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         Make the callback fire in the next reactor iteration.
         """
-        d = Deferred()
-        d.addCallback(lambda x: callback(*args, **kwargs))
-        self.callLater(0, d.callback, True)
-        return d
+        cb = lambda: callback(*args, **kwargs)
+        # it's not safe to call callLater() here, so we append the callback to a
+        # separate queue.
+        self._thread_callbacks.append(cb)
 
     def getThreadPool(self):
         return self.threadpool
@@ -303,6 +307,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
+    def advance(self, amount):
+        # first advance our reactor's time, and run any "callLater" callbacks that
+        # makes ready
+        super().advance(amount)
+
+        # now run any "callFromThread" callbacks
+        while True:
+            try:
+                callback = self._thread_callbacks.popleft()
+            except IndexError:
+                break
+            callback()
+
+            # check for more "callLater" callbacks added by the thread callback
+            # This isn't required in a regular reactor, but it ends up meaning that
+            # our database queries can complete in a single call to `advance` [1] which
+            # simplifies tests.
+            #
+            # [1]: we replace the threadpool backing the db connection pool with a
+            # mock ThreadPool which doesn't really use threads; but we still use
+            # reactor.callFromThread to feed results back from the db functions to the
+            # main thread.
+            super().advance(0)
+
 
 class ThreadPool:
     """
@@ -339,8 +367,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     """
     server = _sth(cleanup_func, *args, **kwargs)
 
-    database = server.config.database.get_single_database()
-
     # Make the thread pool synchronous.
     clock = server.get_clock()
 
@@ -372,6 +398,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
+    # We've just changed the Databases to run DB transactions on the same
+    # thread, so we need to disable the dedicated thread behaviour.
+    server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
     return server
 
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..c905a38930 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         # must be done after inserts
         database = hs.get_datastores().databases[0]
         self.store = ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         db_config = hs.config.get_single_database()
         self.store = TestTransactionStore(
-            database, make_conn(db_config, self.engine), hs
+            database, make_conn(db_config, self.engine, "test"), hs
         )
 
     def _add_service(self, url, as_token, id):
@@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
 
         database = hs.get_datastores().databases[0]
         ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     @defer.inlineCallbacks
@@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
@@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 392b08832b..cc0612cf65 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        # The first ID gen will notice that it can advance its token to 7 as it
+        # has no in progress writes...
+        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
+        # ... but the second ID gen doesn't know that.
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
+
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
 
@@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(
-                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                    first_id_gen.get_positions(), {"first": 7, "second": 7}
                 )
 
         self.get_success(_get_next_async())
@@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
 
-        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        id_gen = self._create_id_generator("worker", writers=["first", "second"])
 
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
@@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
 
-        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
         async def _get_next_async():
             async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 6)
-                self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+                self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
         self.get_success(_get_next_async())
 
@@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("second", 5)
 
         # Initial config has two writers
-        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        id_gen = self._create_id_generator("worker", writers=["first", "second"])
         self.assertEqual(id_gen.get_persisted_upto_position(), 3)
         self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
         self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
@@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async2())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 27a7fc9ed7..d39e792580 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,7 +75,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        self.handler = self.homeserver.get_handlers().federation_handler
+        self.handler = self.homeserver.get_federation_handler()
         self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
             context
         )
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 7657bddea5..e7aed092c2 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -17,7 +17,7 @@ import resource
 
 import mock
 
-from synapse.app.homeserver import phone_stats_home
+from synapse.app.phone_stats_home import phone_stats_home
 
 from tests.unittest import HomeserverTestCase
 
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..5c87f6097e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
         # create a site to wrap the resource.
         self.site = SynapseSite(
             logger_name="synapse.access.http.fake",
-            site_tag="test",
+            site_tag=self.hs.config.server.server_name,
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
@@ -608,7 +608,9 @@ class HomeserverTestCase(TestCase):
         if soft_failed:
             event.internal_metadata.soft_failed = True
 
-        self.get_success(event_creator.send_nonmember_event(requester, event, context))
+        self.get_success(
+            event_creator.handle_new_client_event(requester, event, context)
+        )
 
         return event.event_id
 
diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..0c09f5457f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -38,6 +38,7 @@ from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -88,6 +89,7 @@ def setupdb():
             host=POSTGRES_HOST,
             password=POSTGRES_PASSWORD,
         )
+        db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
         prepare_database(db_conn, db_engine, None)
         db_conn.close()
 
@@ -190,7 +192,6 @@ class TestHomeServer(HomeServer):
 def setup_test_homeserver(
     cleanup_func,
     name="test",
-    datastore=None,
     config=None,
     reactor=None,
     homeserverToUse=TestHomeServer,
@@ -247,7 +248,7 @@ def setup_test_homeserver(
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
-    if datastore is None and isinstance(db_engine, PostgresEngine):
+    if isinstance(db_engine, PostgresEngine):
         db_conn = db_engine.module.connect(
             database=POSTGRES_BASE_DB,
             user=POSTGRES_USER,
@@ -263,79 +264,66 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()
 
-    if datastore is None:
-        hs = homeserverToUse(
-            name,
-            config=config,
-            version_string="Synapse/tests",
-            tls_server_context_factory=Mock(),
-            tls_client_options_factory=Mock(),
-            reactor=reactor,
-            **kargs
-        )
+    hs = homeserverToUse(
+        name,
+        config=config,
+        version_string="Synapse/tests",
+        tls_server_context_factory=Mock(),
+        tls_client_options_factory=Mock(),
+        reactor=reactor,
+        **kargs
+    )
 
-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
+    hs.setup()
+    if homeserverToUse.__name__ == "TestHomeServer":
+        hs.setup_background_tasks()
 
-        if isinstance(db_engine, PostgresEngine):
-            database = hs.get_datastores().databases[0]
+    if isinstance(db_engine, PostgresEngine):
+        database = hs.get_datastores().databases[0]
 
-            # We need to do cleanup on PostgreSQL
-            def cleanup():
-                import psycopg2
+        # We need to do cleanup on PostgreSQL
+        def cleanup():
+            import psycopg2
 
-                # Close all the db pools
-                database._db_pool.close()
+            # Close all the db pools
+            database._db_pool.close()
 
-                dropped = False
+            dropped = False
 
-                # Drop the test database
-                db_conn = db_engine.module.connect(
-                    database=POSTGRES_BASE_DB,
-                    user=POSTGRES_USER,
-                    host=POSTGRES_HOST,
-                    password=POSTGRES_PASSWORD,
-                )
-                db_conn.autocommit = True
-                cur = db_conn.cursor()
-
-                # Try a few times to drop the DB. Some things may hold on to the
-                # database for a few more seconds due to flakiness, preventing
-                # us from dropping it when the test is over. If we can't drop
-                # it, warn and move on.
-                for x in range(5):
-                    try:
-                        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
-                        db_conn.commit()
-                        dropped = True
-                    except psycopg2.OperationalError as e:
-                        warnings.warn(
-                            "Couldn't drop old db: " + str(e), category=UserWarning
-                        )
-                        time.sleep(0.5)
-
-                cur.close()
-                db_conn.close()
-
-                if not dropped:
-                    warnings.warn("Failed to drop old DB.", category=UserWarning)
-
-            if not LEAVE_DB:
-                # Register the cleanup hook
-                cleanup_func(cleanup)
+            # Drop the test database
+            db_conn = db_engine.module.connect(
+                database=POSTGRES_BASE_DB,
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                password=POSTGRES_PASSWORD,
+            )
+            db_conn.autocommit = True
+            cur = db_conn.cursor()
 
-    else:
-        hs = homeserverToUse(
-            name,
-            datastore=datastore,
-            config=config,
-            version_string="Synapse/tests",
-            tls_server_context_factory=Mock(),
-            tls_client_options_factory=Mock(),
-            reactor=reactor,
-            **kargs
-        )
+            # Try a few times to drop the DB. Some things may hold on to the
+            # database for a few more seconds due to flakiness, preventing
+            # us from dropping it when the test is over. If we can't drop
+            # it, warn and move on.
+            for x in range(5):
+                try:
+                    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+                    db_conn.commit()
+                    dropped = True
+                except psycopg2.OperationalError as e:
+                    warnings.warn(
+                        "Couldn't drop old db: " + str(e), category=UserWarning
+                    )
+                    time.sleep(0.5)
+
+            cur.close()
+            db_conn.close()
+
+            if not dropped:
+                warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+        if not LEAVE_DB:
+            # Register the cleanup hook
+            cleanup_func(cleanup)
 
     # bcrypt is far too slow to be doing in unit tests
     # Need to let the HS build an auth handler and then mess with it