summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_device.py82
-rw-r--r--tests/handlers/test_e2e_keys.py65
-rw-r--r--tests/handlers/test_oidc.py10
-rw-r--r--tests/handlers/test_register.py2
-rw-r--r--tests/module_api/test_api.py58
-rw-r--r--tests/replication/_base.py224
-rw-r--r--tests/replication/test_sharded_event_persister.py102
-rw-r--r--tests/rest/client/test_third_party_rules.py144
-rw-r--r--tests/rest/client/third_party_rules.py79
-rw-r--r--tests/rest/client/v1/test_directory.py11
-rw-r--r--tests/server.py4
-rw-r--r--tests/storage/test_appservice.py14
-rw-r--r--tests/test_phone_home.py2
-rw-r--r--tests/unittest.py6
-rw-r--r--tests/utils.py4
15 files changed, 692 insertions, 115 deletions
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 969d44c787..4512c51311 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase):
                 )
             )
             self.reactor.advance(1000)
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver("server", http_client=None)
+        self.handler = hs.get_device_handler()
+        self.registration = hs.get_registration_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        return hs
+
+    def test_dehydrate_and_rehydrate_device(self):
+        user_id = "@boris:dehydration"
+
+        self.get_success(self.store.register_user(user_id, "foobar"))
+
+        # First check if we can store and fetch a dehydrated device
+        stored_dehydrated_device_id = self.get_success(
+            self.handler.store_dehydrated_device(
+                user_id=user_id,
+                device_data={"device_data": {"foo": "bar"}},
+                initial_device_display_name="dehydrated device",
+            )
+        )
+
+        retrieved_device_id, device_data = self.get_success(
+            self.handler.get_dehydrated_device(user_id=user_id)
+        )
+
+        self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+        self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+        # Create a new login for the user and dehydrated the device
+        device_id, access_token = self.get_success(
+            self.registration.register_device(
+                user_id=user_id, device_id=None, initial_display_name="new device",
+            )
+        )
+
+        # Trying to claim a nonexistent device should throw an error
+        self.get_failure(
+            self.handler.rehydrate_device(
+                user_id=user_id,
+                access_token=access_token,
+                device_id="not the right device ID",
+            ),
+            synapse.api.errors.NotFoundError,
+        )
+
+        # dehydrating the right devices should succeed and change our device ID
+        # to the dehydrated device's ID
+        res = self.get_success(
+            self.handler.rehydrate_device(
+                user_id=user_id,
+                access_token=access_token,
+                device_id=retrieved_device_id,
+            )
+        )
+
+        self.assertEqual(res, {"success": True})
+
+        # make sure that our device ID has changed
+        user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
+
+        self.assertEqual(user_info["device_id"], retrieved_device_id)
+
+        # make sure the device has the display name that was set from the login
+        res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
+
+        self.assertEqual(res["display_name"], "new device")
+
+        # make sure that the device ID that we were initially assigned no longer exists
+        self.get_failure(
+            self.handler.get_device(user_id, device_id),
+            synapse.api.errors.NotFoundError,
+        )
+
+        # make sure that there's no device available for dehydrating now
+        ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+
+        self.assertIsNone(ret)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 366dcfb670..4e9e3dcbc2 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -172,6 +172,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_fallback_key(self):
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        fallback_key = {"alg1:k1": "key1"}
+        otk = {"alg1:k2": "key2"}
+
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"org.matrix.msc2732.fallback_keys": fallback_key},
+            )
+        )
+
+        # claiming an OTK when no OTKs are available should return the fallback
+        # key
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # claiming an OTK again should return the same fallback key
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+        # if the user uploads a one-time key, the next claim should fetch the
+        # one-time key, and then go back to the fallback
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": otk}
+            )
+        )
+
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+        )
+
+        res = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
+        )
+        self.assertEqual(
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+        )
+
+    @defer.inlineCallbacks
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d5087e58be..b6f436c016 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -286,9 +286,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 h._validate_metadata,
             )
 
-        # Tests for configs that the userinfo endpoint
+        # Tests for configs that require the userinfo endpoint
         self.assertFalse(h._uses_userinfo)
-        h._scopes = []  # do not request the openid scope
+        self.assertEqual(h._user_profile_method, "auto")
+        h._user_profile_method = "userinfo_endpoint"
+        self.assertTrue(h._uses_userinfo)
+
+        # Revert the profile method and do not request the "openid" scope.
+        h._user_profile_method = "auto"
+        h._scopes = []
         self.assertTrue(h._uses_userinfo)
         self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..702c6aa089 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -413,7 +413,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            event_creation_handler.send_nonmember_event(requester, event, context)
+            event_creation_handler.handle_new_client_event(requester, event, context)
         )
 
         # Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 04de0b9dbe..7c790bee7d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -13,15 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 
 from tests.unittest import HomeserverTestCase
 
 
 class ModuleApiTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, homeserver):
         self.store = homeserver.get_datastore()
-        self.module_api = ModuleApi(homeserver, homeserver.get_auth_handler())
+        self.module_api = homeserver.get_module_api()
 
     def test_can_register_user(self):
         """Tests that an external module can register a user"""
@@ -52,3 +59,50 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that the displayname was assigned
         displayname = self.get_success(self.store.get_profile_displayname("bob"))
         self.assertEqual(displayname, "Bobberino")
+
+    def test_public_rooms(self):
+        """Tests that a room can be added and removed from the public rooms list,
+        as well as have its public rooms directory state queried.
+        """
+        # Create a user and room to play with
+        user_id = self.register_user("kermit", "monkey")
+        tok = self.login("kermit", "monkey")
+        room_id = self.helper.create_room_as(user_id, tok=tok)
+
+        # The room should not currently be in the public rooms directory
+        is_in_public_rooms = self.get_success(
+            self.module_api.public_room_list_manager.room_is_in_public_room_list(
+                room_id
+            )
+        )
+        self.assertFalse(is_in_public_rooms)
+
+        # Let's try adding it to the public rooms directory
+        self.get_success(
+            self.module_api.public_room_list_manager.add_room_to_public_room_list(
+                room_id
+            )
+        )
+
+        # And checking whether it's in there...
+        is_in_public_rooms = self.get_success(
+            self.module_api.public_room_list_manager.room_is_in_public_room_list(
+                room_id
+            )
+        )
+        self.assertTrue(is_in_public_rooms)
+
+        # Let's remove it again
+        self.get_success(
+            self.module_api.public_room_list_manager.remove_room_from_public_room_list(
+                room_id
+            )
+        )
+
+        # Should be gone
+        is_in_public_rooms = self.get_success(
+            self.module_api.public_room_list_manager.room_is_in_public_room_list(
+                room_id
+            )
+        )
+        self.assertFalse(is_in_public_rooms)
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ae60874ec3..81ea985b9f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Any, Callable, List, Optional, Tuple
 
 import attr
+import hiredis
 
 from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
+from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
@@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
     GenericWorkerServer,
 )
 from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource, streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = self.hs.get_replication_streamer()
 
+        # Fake in memory Redis server that servers can connect to.
+        self._redis_server = FakeRedisPubSubServer()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
+        self.reactor.lookups["localhost"] = "127.0.0.1"
+
+        # A map from a HS instance to the associated HTTP Site to use for
+        # handling inbound HTTP requests to that instance.
+        self._hs_to_site = {self.hs: self.site}
+
+        if self.hs.config.redis.redis_enabled:
+            # Handle attempts to connect to fake redis server.
+            self.reactor.add_tcp_client_callback(
+                "localhost", 6379, self.connect_any_redis_attempts,
+            )
 
-        self._worker_hs_to_resource = {}
+            self.hs.get_tcp_replication().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
         # manually have to go and explicitly set it up each time (plus sometimes
         # it is impossible to write the handling explicitly in the tests).
+        #
+        # Register the master replication listener:
         self.reactor.add_tcp_client_callback(
-            "1.2.3.4", 8765, self._handle_http_replication_attempt
+            "1.2.3.4",
+            8765,
+            lambda: self._handle_http_replication_attempt(self.hs, 8765),
         )
 
     def create_test_json_resource(self):
@@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             **kwargs
         )
 
+        # If the instance is in the `instance_map` config then workers may try
+        # and send HTTP requests to it, so we register it with
+        # `_handle_http_replication_attempt` like we do with the master HS.
+        instance_name = worker_hs.get_instance_name()
+        instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
+        if instance_loc:
+            # Ensure the host is one that has a fake DNS entry.
+            if instance_loc.host not in self.reactor.lookups:
+                raise Exception(
+                    "Host does not have an IP for instance_map[%r].host = %r"
+                    % (instance_name, instance_loc.host,)
+                )
+
+            self.reactor.add_tcp_client_callback(
+                self.reactor.lookups[instance_loc.host],
+                instance_loc.port,
+                lambda: self._handle_http_replication_attempt(
+                    worker_hs, instance_loc.port
+                ),
+            )
+
         store = worker_hs.get_datastore()
         store.db_pool._db_pool = self.database_pool._db_pool
 
-        repl_handler = ReplicationCommandHandler(worker_hs)
-        client = ClientReplicationStreamProtocol(
-            worker_hs, "client", "test", self.clock, repl_handler,
-        )
-        server = self.server_factory.buildProtocol(None)
+        # Set up TCP replication between master and the new worker if we don't
+        # have Redis support enabled.
+        if not worker_hs.config.redis_enabled:
+            repl_handler = ReplicationCommandHandler(worker_hs)
+            client = ClientReplicationStreamProtocol(
+                worker_hs, "client", "test", self.clock, repl_handler,
+            )
+            server = self.server_factory.buildProtocol(None)
 
-        client_transport = FakeTransport(server, self.reactor)
-        client.makeConnection(client_transport)
+            client_transport = FakeTransport(server, self.reactor)
+            client.makeConnection(client_transport)
 
-        server_transport = FakeTransport(client, self.reactor)
-        server.makeConnection(server_transport)
+            server_transport = FakeTransport(client, self.reactor)
+            server.makeConnection(server_transport)
 
         # Set up a resource for the worker
-        resource = ReplicationRestResource(self.hs)
+        resource = ReplicationRestResource(worker_hs)
 
         for servlet in self.servlets:
             servlet(worker_hs, resource)
 
-        self._worker_hs_to_resource[worker_hs] = resource
+        self._hs_to_site[worker_hs] = SynapseSite(
+            logger_name="synapse.access.http.fake",
+            site_tag="{}-{}".format(
+                worker_hs.config.server.server_name, worker_hs.get_instance_name()
+            ),
+            config=worker_hs.config.server.listeners[0],
+            resource=resource,
+            server_version_string="1",
+        )
+
+        if worker_hs.config.redis.redis_enabled:
+            worker_hs.get_tcp_replication().start_replication(worker_hs)
 
         return worker_hs
 
@@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         return config
 
     def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
-        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+        render(request, self._hs_to_site[worker_hs].resource, self.reactor)
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.pump()
 
-    def _handle_http_replication_attempt(self):
-        """Handles a connection attempt to the master replication HTTP
-        listener.
+    def _handle_http_replication_attempt(self, hs, repl_port):
+        """Handles a connection attempt to the given HS replication HTTP
+        listener on the given port.
         """
 
         # We should have at least one outbound connection attempt, where the
@@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.assertGreaterEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
         self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 8765)
+        self.assertEqual(port, repl_port)
 
         # Set up client side protocol
         client_protocol = client_factory.buildProtocol(None)
@@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Set up the server side protocol
         channel = _PushHTTPChannel(self.reactor)
         channel.requestFactory = request_factory
-        channel.site = self.site
+        channel.site = self._hs_to_site[hs]
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # inside `connecTCP` before the connection has been passed back to the
         # code that requested the TCP connection.
 
+    def connect_any_redis_attempts(self):
+        """If redis is enabled we need to deal with workers connecting to a
+        redis server. We don't want to use a real Redis server so we use a
+        fake one.
+        """
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, "localhost")
+        self.assertEqual(port, 6379)
+
+        client_protocol = client_factory.buildProtocol(None)
+        server_protocol = self._redis_server.buildProtocol(None)
+
+        client_to_server_transport = FakeTransport(
+            server_protocol, self.reactor, client_protocol
+        )
+        client_protocol.makeConnection(client_to_server_transport)
+
+        server_to_client_transport = FakeTransport(
+            client_protocol, self.reactor, server_protocol
+        )
+        server_protocol.makeConnection(server_to_client_transport)
+
+        return client_to_server_transport, server_to_client_transport
+
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -467,3 +547,105 @@ class _PullToPushProducer:
                 pass
 
             self.stopProducing()
+
+
+class FakeRedisPubSubServer:
+    """A fake Redis server for pub/sub.
+    """
+
+    def __init__(self):
+        self._subscribers = set()
+
+    def add_subscriber(self, conn):
+        """A connection has called SUBSCRIBE
+        """
+        self._subscribers.add(conn)
+
+    def remove_subscriber(self, conn):
+        """A connection has called UNSUBSCRIBE
+        """
+        self._subscribers.discard(conn)
+
+    def publish(self, conn, channel, msg) -> int:
+        """A connection want to publish a message to subscribers.
+        """
+        for sub in self._subscribers:
+            sub.send(["message", channel, msg])
+
+        return len(self._subscribers)
+
+    def buildProtocol(self, addr):
+        return FakeRedisPubSubProtocol(self)
+
+
+class FakeRedisPubSubProtocol(Protocol):
+    """A connection from a client talking to the fake Redis server.
+    """
+
+    def __init__(self, server: FakeRedisPubSubServer):
+        self._server = server
+        self._reader = hiredis.Reader()
+
+    def dataReceived(self, data):
+        self._reader.feed(data)
+
+        # We might get multiple messages in one packet.
+        while True:
+            msg = self._reader.gets()
+
+            if msg is False:
+                # No more messages.
+                return
+
+            if not isinstance(msg, list):
+                # Inbound commands should always be a list
+                raise Exception("Expected redis list")
+
+            self.handle_command(msg[0], *msg[1:])
+
+    def handle_command(self, command, *args):
+        """Received a Redis command from the client.
+        """
+
+        # We currently only support pub/sub.
+        if command == b"PUBLISH":
+            channel, message = args
+            num_subscribers = self._server.publish(self, channel, message)
+            self.send(num_subscribers)
+        elif command == b"SUBSCRIBE":
+            (channel,) = args
+            self._server.add_subscriber(self)
+            self.send(["subscribe", channel, 1])
+        else:
+            raise Exception("Unknown command")
+
+    def send(self, msg):
+        """Send a message back to the client.
+        """
+        raw = self.encode(msg).encode("utf-8")
+
+        self.transport.write(raw)
+        self.transport.flush()
+
+    def encode(self, obj):
+        """Encode an object to its Redis format.
+
+        Supports: strings/bytes, integers and list/tuples.
+        """
+
+        if isinstance(obj, bytes):
+            # We assume bytes are just unicode strings.
+            obj = obj.decode("utf-8")
+
+        if isinstance(obj, str):
+            return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
+        if isinstance(obj, int):
+            return ":{val}\r\n".format(val=obj)
+        if isinstance(obj, (list, tuple)):
+            items = "".join(self.encode(a) for a in obj)
+            return "*{len}\r\n{items}".format(len=len(obj), items=items)
+
+        raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
+
+    def connectionLost(self, reason):
+        self._server.remove_subscriber(self)
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
new file mode 100644
index 0000000000..6068d14905
--- /dev/null
+++ b/tests/replication/test_sharded_event_persister.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+logger = logging.getLogger(__name__)
+
+
+class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks event persisting sharding works
+    """
+
+    # Event persister sharding requires postgres (due to needing
+    # `MutliWriterIdGenerator`).
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["redis"] = {"enabled": "true"}
+        conf["stream_writers"] = {"events": ["worker1", "worker2"]}
+        conf["instance_map"] = {
+            "worker1": {"host": "testserv", "port": 1001},
+            "worker2": {"host": "testserv", "port": 1002},
+        }
+        return conf
+
+    def test_basic(self):
+        """Simple test to ensure that multiple rooms can be created and joined,
+        and that different rooms get handled by different instances.
+        """
+
+        self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "worker1"},
+        )
+
+        self.make_worker_hs(
+            "synapse.app.generic_worker", {"worker_name": "worker2"},
+        )
+
+        persisted_on_1 = False
+        persisted_on_2 = False
+
+        store = self.hs.get_datastore()
+
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Keep making new rooms until we see rooms being persisted on both
+        # workers.
+        for _ in range(10):
+            # Create a room
+            room = self.helper.create_room_as(user_id, tok=access_token)
+
+            # The other user joins
+            self.helper.join(
+                room=room, user=self.other_user_id, tok=self.other_access_token
+            )
+
+            # The other user sends some messages
+            rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+            event_id = rseponse["event_id"]
+
+            # The event position includes which instance persisted the event.
+            pos = self.get_success(store.get_position_for_event(event_id))
+
+            persisted_on_1 |= pos.instance_name == "worker1"
+            persisted_on_2 |= pos.instance_name == "worker2"
+
+            if persisted_on_1 and persisted_on_2:
+                break
+
+        self.assertTrue(persisted_on_1)
+        self.assertTrue(persisted_on_2)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..c12518c931
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+    def __init__(self, config, module_api):
+        # keep a record of the "current" rules module, so that the test can patch
+        # it if desired.
+        thread_local.rules_module = self
+
+    async def on_create_room(
+        self, requester: Requester, config: dict, is_requester_admin: bool
+    ):
+        return True
+
+    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+        return True
+
+    @staticmethod
+    def parse_config(config):
+        return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+    return thread_local.rules_module
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["third_party_event_rules"] = {
+            "module": __name__ + ".ThirdPartyRulesTestModule",
+            "config": {},
+        }
+        return config
+
+    def prepare(self, reactor, clock, homeserver):
+        # Create a user and room to play with during the tests
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+    def test_third_party_rules(self):
+        """Tests that a forbidden event is forbidden from being sent, but an allowed one
+        can be sent.
+        """
+        # patch the rules module with a Mock which will return False for some event
+        # types
+        async def check(ev, state):
+            return ev.type != "foo.bar.forbidden"
+
+        callback = Mock(spec=[], side_effect=check)
+        current_rules_module().check_event_allowed = callback
+
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+            {},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        callback.assert_called_once()
+
+        # there should be various state events in the state arg: do some basic checks
+        state_arg = callback.call_args[0][1]
+        for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+            self.assertIn(k, state_arg)
+            ev = state_arg[k]
+            self.assertEqual(ev.type, k[0])
+            self.assertEqual(ev.state_key, k[1])
+
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+            {},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"403", channel.result)
+
+    def test_modify_event(self):
+        """Tests that the module can successfully tweak an event before it is persisted.
+        """
+        # first patch the event checker so that it will modify the event
+        async def check(ev: EventBase, state):
+            ev.content = {"x": "y"}
+            return True
+
+        current_rules_module().check_event_allowed = check
+
+        # now send the event
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+            {"x": "x"},
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        event_id = channel.json_body["event_id"]
+
+        # ... and check that it got modified
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["x"], "y")
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 8c24add530..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule:
-    def __init__(self, config):
-        pass
-
-    def check_event_allowed(self, event, context):
-        if event.type == "foo.bar.forbidden":
-            return False
-        else:
-            return True
-
-    @staticmethod
-    def parse_config(config):
-        return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        admin.register_servlets,
-        login.register_servlets,
-        room.register_servlets,
-    ]
-
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
-        config["third_party_event_rules"] = {
-            "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
-            "config": {},
-        }
-
-        self.hs = self.setup_test_homeserver(config=config)
-        return self.hs
-
-    def test_third_party_rules(self):
-        """Tests that a forbidden event is forbidden from being sent, but an allowed one
-        can be sent.
-        """
-        user_id = self.register_user("kermit", "monkey")
-        tok = self.login("kermit", "monkey")
-
-        room_id = self.helper.create_room_as(user_id, tok=tok)
-
-        request, channel = self.make_request(
-            "PUT",
-            "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
-            {},
-            access_token=tok,
-        )
-        self.render(request)
-        self.assertEquals(channel.result["code"], b"200", channel.result)
-
-        request, channel = self.make_request(
-            "PUT",
-            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
-            {},
-            access_token=tok,
-        )
-        self.render(request)
-        self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 633b7dbda0..ea5a7f3739 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -21,6 +21,7 @@ from synapse.types import RoomAlias
 from synapse.util.stringutils import random_string
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -67,10 +68,18 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.ensure_user_joined_room()
         self.set_alias_via_directory(400, alias_length=256)
 
-    def test_state_event_in_room(self):
+    @override_config({"default_room_version": 5})
+    def test_state_event_user_in_v5_room(self):
+        """Test that a regular user can add alias events before room v6"""
         self.ensure_user_joined_room()
         self.set_alias_via_state_event(200)
 
+    @override_config({"default_room_version": 6})
+    def test_state_event_v6_room(self):
+        """Test that a regular user can *not* add alias events from room v6"""
+        self.ensure_user_joined_room()
+        self.set_alias_via_state_event(403)
+
     def test_directory_in_room(self):
         self.ensure_user_joined_room()
         self.set_alias_via_directory(200)
diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..f7f5276b21 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -372,6 +372,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
+    # We've just changed the Databases to run DB transactions on the same
+    # thread, so we need to disable the dedicated thread behaviour.
+    server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
     return server
 
 
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..c905a38930 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         # must be done after inserts
         database = hs.get_datastores().databases[0]
         self.store = ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         db_config = hs.config.get_single_database()
         self.store = TestTransactionStore(
-            database, make_conn(db_config, self.engine), hs
+            database, make_conn(db_config, self.engine, "test"), hs
         )
 
     def _add_service(self, url, as_token, id):
@@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
 
         database = hs.get_datastores().databases[0]
         ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     @defer.inlineCallbacks
@@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
@@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index 7657bddea5..e7aed092c2 100644
--- a/tests/test_phone_home.py
+++ b/tests/test_phone_home.py
@@ -17,7 +17,7 @@ import resource
 
 import mock
 
-from synapse.app.homeserver import phone_stats_home
+from synapse.app.phone_stats_home import phone_stats_home
 
 from tests.unittest import HomeserverTestCase
 
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..5c87f6097e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
         # create a site to wrap the resource.
         self.site = SynapseSite(
             logger_name="synapse.access.http.fake",
-            site_tag="test",
+            site_tag=self.hs.config.server.server_name,
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
@@ -608,7 +608,9 @@ class HomeserverTestCase(TestCase):
         if soft_failed:
             event.internal_metadata.soft_failed = True
 
-        self.get_success(event_creator.send_nonmember_event(requester, event, context))
+        self.get_success(
+            event_creator.handle_new_client_event(requester, event, context)
+        )
 
         return event.event_id
 
diff --git a/tests/utils.py b/tests/utils.py
index 4673872f88..af563ffe0f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -38,6 +38,7 @@ from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -88,6 +89,7 @@ def setupdb():
             host=POSTGRES_HOST,
             password=POSTGRES_PASSWORD,
         )
+        db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
         prepare_database(db_conn, db_engine, None)
         db_conn.close()
 
@@ -276,7 +278,7 @@ def setup_test_homeserver(
 
         hs.setup()
         if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
+            hs.setup_background_tasks()
 
         if isinstance(db_engine, PostgresEngine):
             database = hs.get_datastores().databases[0]