summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/appservice/test_appservice.py89
-rw-r--r--tests/appservice/test_scheduler.py19
-rw-r--r--tests/crypto/test_keyring.py50
-rw-r--r--tests/federation/test_complexity.py118
-rw-r--r--tests/federation/test_federation_sender.py10
-rw-r--r--tests/handlers/test_appservice.py5
-rw-r--r--tests/handlers/test_directory.py5
-rw-r--r--tests/handlers/test_profile.py3
-rw-r--r--tests/handlers/test_stats.py108
-rw-r--r--tests/handlers/test_user_directory.py22
-rw-r--r--tests/http/test_fedclient.py50
-rw-r--r--tests/replication/_base.py6
-rw-r--r--tests/replication/slave/storage/test_events.py6
-rw-r--r--tests/replication/test_federation_sender_shard.py13
-rw-r--r--tests/rest/admin/test_admin.py4
-rw-r--r--tests/rest/admin/test_room.py2947
-rw-r--r--tests/rest/client/v1/utils.py44
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py157
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/rest/media/v1/test_url_preview.py142
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py2
-rw-r--r--tests/storage/test__base.py16
-rw-r--r--tests/storage/test_appservice.py6
-rw-r--r--tests/storage/test_background_update.py8
-rw-r--r--tests/storage/test_base.py22
-rw-r--r--tests/storage/test_cleanup_extrems.py12
-rw-r--r--tests/storage/test_client_ips.py26
-rw-r--r--tests/storage/test_event_federation.py16
-rw-r--r--tests/storage/test_event_push_actions.py30
-rw-r--r--tests/storage/test_id_generators.py14
-rw-r--r--tests/storage/test_monthly_active_users.py6
-rw-r--r--tests/storage/test_purge.py8
-rw-r--r--tests/storage/test_redaction.py8
-rw-r--r--tests/storage/test_room.py30
-rw-r--r--tests/storage/test_roommember.py12
-rw-r--r--tests/storage/test_state.py76
-rw-r--r--tests/test_federation.py2
-rw-r--r--tests/test_server.py45
-rw-r--r--tests/test_state.py14
-rw-r--r--tests/test_visibility.py26
-rw-r--r--tests/unittest.py6
-rw-r--r--tests/utils.py20
42 files changed, 2427 insertions, 1780 deletions
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 4003869ed6..236b608d58 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
     def test_regex_user_id_prefix_match(self):
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_no_match(self):
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@someone_else:matrix.org"
-        self.assertFalse((yield self.service.is_interested(self.event)))
+        self.assertFalse(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_room_member_is_checked(self):
@@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.sender = "@someone_else:matrix.org"
         self.event.type = "m.room.member"
         self.event.state_key = "@irc_foobar:matrix.org"
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_room_id_match(self):
@@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_room_id_no_match(self):
@@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
-        self.assertFalse((yield self.service.is_interested(self.event)))
+        self.assertFalse(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_alias_match(self):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room.return_value = [
-            "#irc_foobar:matrix.org",
-            "#athing:matrix.org",
-        ]
-        self.store.get_users_in_room.return_value = []
-        self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+        self.store.get_aliases_for_room.return_value = defer.succeed(
+            ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+        )
+        self.store.get_users_in_room.return_value = defer.succeed([])
+        self.assertTrue(
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(self.event, self.store)
+                )
+            )
+        )
 
     def test_non_exclusive_alias(self):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room.return_value = [
-            "#xmpp_foobar:matrix.org",
-            "#athing:matrix.org",
-        ]
-        self.store.get_users_in_room.return_value = []
-        self.assertFalse((yield self.service.is_interested(self.event, self.store)))
+        self.store.get_aliases_for_room.return_value = defer.succeed(
+            ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+        )
+        self.store.get_users_in_room.return_value = defer.succeed([])
+        self.assertFalse(
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(self.event, self.store)
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     def test_regex_multiple_matches(self):
@@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
-        self.store.get_users_in_room.return_value = []
-        self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+        self.store.get_aliases_for_room.return_value = defer.succeed(
+            ["#irc_barfoo:matrix.org"]
+        )
+        self.store.get_users_in_room.return_value = defer.succeed([])
+        self.assertTrue(
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(self.event, self.store)
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     def test_interested_in_self(self):
@@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.type = "m.room.member"
         self.event.content = {"membership": "invite"}
         self.event.state_key = self.service.sender
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_member_list_match(self):
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
-        self.store.get_users_in_room.return_value = [
-            "@alice:here",
-            "@irc_fo:here",  # AS user
-            "@bob:here",
-        ]
-        self.store.get_aliases_for_room.return_value = []
+        # Note that @irc_fo:here is the AS user.
+        self.store.get_users_in_room.return_value = defer.succeed(
+            ["@alice:here", "@irc_fo:here", "@bob:here"]
+        )
+        self.store.get_aliases_for_room.return_value = defer.succeed([])
 
         self.event.sender = "@xmpp_foobar:matrix.org"
         self.assertTrue(
-            (yield self.service.is_interested(event=self.event, store=self.store))
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(event=self.event, store=self.store)
+                )
+            )
         )
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 52f89d3f83..68a4caabbf 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
 from synapse.logging.context import make_deferred_yieldable
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 from ..utils import MockClock
 
@@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         self.store.get_appservice_state = Mock(
             return_value=defer.succeed(ApplicationServiceState.UP)
         )
-        txn.send = Mock(return_value=defer.succeed(True))
+        txn.send = Mock(return_value=make_awaitable(True))
         self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
 
         # actual call
-        self.txnctrl.send(service, events)
+        self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
             service=service, events=events  # txn made and saved
@@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
 
         # actual call
-        self.txnctrl.send(service, events)
+        self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
             service=service, events=events  # txn made and saved
@@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             return_value=defer.succeed(ApplicationServiceState.UP)
         )
         self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
-        txn.send = Mock(return_value=defer.succeed(False))  # fails to send
+        txn.send = Mock(return_value=make_awaitable(False))  # fails to send
         self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
 
         # actual call
-        self.txnctrl.send(service, events)
+        self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
             service=service, events=events
@@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.recoverer.recover()
         # shouldn't have called anything prior to waiting for exp backoff
         self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = Mock(return_value=True)
+        txn.send = Mock(return_value=make_awaitable(True))
+        txn.complete.return_value = make_awaitable(None)
         # wait for exp backoff
         self.clock.advance_time(2)
         self.assertEquals(1, txn.send.call_count)
@@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
 
         self.recoverer.recover()
         self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = Mock(return_value=False)
+        txn.send = Mock(return_value=make_awaitable(False))
+        txn.complete.return_value = make_awaitable(None)
         self.clock.advance_time(2)
         self.assertEquals(1, txn.send.call_count)
         self.assertEquals(0, txn.complete.call_count)
@@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.assertEquals(3, txn.send.call_count)
         self.assertEquals(0, txn.complete.call_count)
         self.assertEquals(0, self.callback.call_count)
-        txn.send = Mock(return_value=True)  # successfully send the txn
+        txn.send = Mock(return_value=make_awaitable(True))  # successfully send the txn
         pop_txn = True  # returns the txn the first time, then no more.
         self.clock.advance_time(16)
         self.assertEquals(1, txn.send.call_count)  # new mock reset call count
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index f9ce609923..0d4b05304b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -40,6 +40,7 @@ from synapse.logging.context import (
 from synapse.storage.keys import FetchKeyResult
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class MockPerspectiveServer(object):
@@ -102,11 +103,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         }
         persp_deferred = defer.Deferred()
 
-        @defer.inlineCallbacks
-        def get_perspectives(**kwargs):
+        async def get_perspectives(**kwargs):
             self.assertEquals(current_context().request, "11")
             with PreserveLoggingContext():
-                yield persp_deferred
+                await persp_deferred
             return persp_resp
 
         self.http_client.post_json.side_effect = get_perspectives
@@ -202,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         with a null `ts_valid_until_ms`
         """
         mock_fetcher = keyring.KeyFetcher()
-        mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+        mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
         kr = keyring.Keyring(
             self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -245,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key(1)
 
-        def get_keys(keys_to_fetch):
+        async def get_keys(keys_to_fetch):
             # there should only be one request object (with the max validity)
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
 
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                    }
+            return {
+                "server1": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
                 }
-            )
+            }
 
         mock_fetcher = keyring.KeyFetcher()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -282,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """If the first fetcher cannot provide a recent enough key, we fall back"""
         key1 = signedjson.key.generate_signing_key(1)
 
-        def get_keys1(keys_to_fetch):
+        async def get_keys1(keys_to_fetch):
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
-                    }
-                }
-            )
+            return {
+                "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+            }
 
-        def get_keys2(keys_to_fetch):
+        async def get_keys2(keys_to_fetch):
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                    }
+            return {
+                "server1": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
                 }
-            )
+            }
 
         mock_fetcher1 = keyring.KeyFetcher()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -355,7 +347,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         }
         signedjson.sign.sign_json(response, SERVER_NAME, testkey)
 
-        def get_json(destination, path, **kwargs):
+        async def get_json(destination, path, **kwargs):
             self.assertEqual(destination, SERVER_NAME)
             self.assertEqual(path, "/_matrix/key/v2/server/key1")
             return response
@@ -444,7 +436,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect a perspectives-server key query
         """
 
-        def post_json(destination, path, data, **kwargs):
+        async def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
             self.assertEqual(path, "/_matrix/key/v2/query")
 
@@ -580,14 +572,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # remove the perspectives server's signature
         response = build_response()
         del response["signatures"][self.mock_perspective_server.server_name]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
 
         # remove the origin server's signature
         response = build_response()
         del response["signatures"][SERVER_NAME]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
 
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0c9987be54..b8ca118716 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -78,9 +79,40 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
+        )
+
+        d = handler._remote_join(
+            None,
+            ["other.example.com"],
+            "roomid",
+            UserID.from_string(u1),
+            {"membership": "join"},
+        )
+
+        self.pump()
+
+        # The request failed with a SynapseError saying the resource limit was
+        # exceeded.
+        f = self.get_failure(d, SynapseError)
+        self.assertEqual(f.value.code, 400, f.value)
+        self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+    def test_join_too_large_admin(self):
+        # Check whether an admin can join if option "admins_can_join" is undefined,
+        # this option defaults to false, so the join should fail.
+
+        u1 = self.register_user("u1", "pass", admin=True)
+
+        handler = self.hs.get_room_member_handler()
+        fed_transport = self.hs.get_federation_transport_client()
+
+        # Mock out some things, because we don't want to test the whole join
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        handler.federation_handler.do_invite_join = Mock(
+            return_value=make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -116,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            return_value=make_awaitable(("", 1))
         )
 
         # Artificially raise the complexity
@@ -141,3 +173,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         f = self.get_failure(d, SynapseError)
         self.assertEqual(f.value.code, 400)
         self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+
+class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
+    # Test the behavior of joining rooms which exceed the complexity if option
+    # limit_remote_rooms.admins_can_join is True.
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["limit_remote_rooms"] = {
+            "enabled": True,
+            "complexity": 0.05,
+            "admins_can_join": True,
+        }
+        return config
+
+    def test_join_too_large_no_admin(self):
+        # A user which is not an admin should not be able to join a remote room
+        # which is too complex.
+
+        u1 = self.register_user("u1", "pass")
+
+        handler = self.hs.get_room_member_handler()
+        fed_transport = self.hs.get_federation_transport_client()
+
+        # Mock out some things, because we don't want to test the whole join
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        handler.federation_handler.do_invite_join = Mock(
+            return_value=make_awaitable(("", 1))
+        )
+
+        d = handler._remote_join(
+            None,
+            ["other.example.com"],
+            "roomid",
+            UserID.from_string(u1),
+            {"membership": "join"},
+        )
+
+        self.pump()
+
+        # The request failed with a SynapseError saying the resource limit was
+        # exceeded.
+        f = self.get_failure(d, SynapseError)
+        self.assertEqual(f.value.code, 400, f.value)
+        self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+    def test_join_too_large_admin(self):
+        # An admin should be able to join rooms where a complexity check fails.
+
+        u1 = self.register_user("u1", "pass", admin=True)
+
+        handler = self.hs.get_room_member_handler()
+        fed_transport = self.hs.get_federation_transport_client()
+
+        # Mock out some things, because we don't want to test the whole join
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        handler.federation_handler.do_invite_join = Mock(
+            return_value=make_awaitable(("", 1))
+        )
+
+        d = handler._remote_join(
+            None,
+            ["other.example.com"],
+            "roomid",
+            UserID.from_string(u1),
+            {"membership": "join"},
+        )
+
+        self.pump()
+
+        # The request success since the user is an admin
+        self.get_success(d)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index d1bd18da39..5f512ff8bf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
         self.pump()
         mock_send_transaction.assert_not_called()
 
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ebabe9a7d6..628f7d8db0 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 
 from synapse.handlers.appservice import ApplicationServicesHandler
 
+from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
 from .. import unittest
@@ -117,7 +118,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice_alias(is_interested_in_alias=False),
         ]
 
-        self.mock_as_api.query_alias.return_value = defer.succeed(True)
+        self.mock_as_api.query_alias.return_value = make_awaitable(True)
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
             Mock(room_id=room_id, servers=servers)
@@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def _mkservice(self, is_interested):
         service = Mock()
-        service.is_interested.return_value = defer.succeed(is_interested)
+        service.is_interested.return_value = make_awaitable(is_interested)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse
 import synapse.api.errors
 from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
 from synapse.types import RoomAlias, create_requester
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 4f1347cd25..d70e1fc608 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_other_name(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..0e666492f6 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -15,7 +15,7 @@
 
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
-from synapse.storage.data_stores.main import stats
+from synapse.storage.databases.main import stats
 
 from tests import unittest
 
@@ -42,36 +42,36 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
-                    "update_name": "populate_stats_process_rooms",
+                    "update_name": "populate_stats_process_rooms_2",
                     "progress_json": "{}",
                     "depends_on": "populate_stats_prepare",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
+                    "depends_on": "populate_stats_process_rooms_2",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
     def get_all_room_state(self):
-        return self.store.db.simple_select_list(
+        return self.store.db_pool.simple_select_list(
             "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
         )
 
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
 
         return self.get_success(
-            self.store.db.simple_select_one(
+            self.store.db_pool.simple_select_one(
                 table + "_historical",
                 {id_col: stat_id, end_ts: end_ts},
                 cols,
@@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
     def test_initial_room(self):
@@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         r = self.get_success(self.get_all_room_state())
@@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # the position that the deltas should begin at, once they take over.
         self.hs.config.stats_enabled = True
         self.handler.stats_enabled = True
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
         self.get_success(
-            self.store.db.simple_update_one(
+            self.store.db_pool.simple_update_one(
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": 0},
@@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Now, before the table is actually ingested, add some more events.
@@ -217,28 +217,31 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # Now do the initial ingestion.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
-                {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+                {
+                    "update_name": "populate_stats_process_rooms_2",
+                    "progress_json": "{}",
+                },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
+                    "depends_on": "populate_stats_process_rooms_2",
                 },
             )
         )
 
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         self.reactor.advance(86401)
@@ -346,6 +349,37 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
 
+    def test_updating_profile_information_does_not_increase_joined_members_count(self):
+        """
+        Check that the joined_members count does not increase when a user changes their
+        profile information (which is done by sending another join membership event into
+        the room.
+        """
+        self._perform_background_initial_update()
+
+        # Create a user and room
+        u1 = self.register_user("u1", "pass")
+        u1token = self.login("u1", "pass")
+        r1 = self.helper.create_room_as(u1, tok=u1token)
+
+        # Get the current room stats
+        r1stats_ante = self._get_current_stats("room", r1)
+
+        # Send a profile update into the room
+        new_profile = {"displayname": "bob"}
+        self.helper.change_membership(
+            r1, u1, u1, "join", extra_data=new_profile, tok=u1token
+        )
+
+        # Get the new room stats
+        r1stats_post = self._get_current_stats("room", r1)
+
+        # Ensure that the user count did not changed
+        self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
+        self.assertEqual(
+            r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
+        )
+
     def test_send_state_event_nonoverwriting(self):
         """
         When we send a non-overwriting state event, it increments total_events AND current_state_events
@@ -669,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # preparation stage of the initial background update
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         self.get_success(
-            self.store.db.simple_delete(
+            self.store.db_pool.simple_delete(
                 "room_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
         self.get_success(
-            self.store.db.simple_delete(
+            self.store.db_pool.simple_delete(
                 "user_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
@@ -689,29 +723,29 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # now do the background updates
 
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
-                    "update_name": "populate_stats_process_rooms",
+                    "update_name": "populate_stats_process_rooms_2",
                     "progress_json": "{}",
                     "depends_on": "populate_stats_prepare",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
+                    "depends_on": "populate_stats_process_rooms_2",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -722,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         r1stats_complete = self._get_current_stats("room", r1)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 23fcc372dd..31ed89a5cd 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -339,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_in_public_rooms(self):
         r = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 "users_in_public_rooms", None, ("user_id", "room_id")
             )
         )
@@ -350,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_who_share_private_rooms(self):
         return self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 "users_who_share_private_rooms",
                 None,
                 ["user_id", "other_user_id", "room_id"],
@@ -362,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_createtables",
@@ -374,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_rooms",
@@ -384,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_users",
@@ -394,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_cleanup",
@@ -437,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         shares_private = self.get_users_who_share_private_rooms()
@@ -476,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         shares_private = self.get_users_who_share_private_rooms()
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..ac598249e4 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase):
         @defer.inlineCallbacks
         def do_request():
             with LoggingContext("one") as context:
-                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+                fetch_d = defer.ensureDeferred(
+                    self.cl.get_json("testserv:8008", "foo/bar")
+                )
 
                 # Nothing happened yet
                 self.assertNoResult(fetch_d)
@@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase):
         """
         If the DNS lookup returns an error, it will bubble up.
         """
-        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        )
         self.pump()
 
         f = self.failureResultOf(d)
@@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value.inner_exception, DNSLookupError)
 
     def test_client_connection_refused(self):
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a blacklisted IPv4 address
         # ------------------------------------------------------
         # Make the request
-        d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
 
         # Nothing happened yet
         self.assertNoResult(d)
@@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a POST request to a blacklisted IPv6 address
         # -------------------------------------------------------
         # Make the request
-        d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        )
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a non-blacklisted IPv4 address
         # ----------------------------------------------------------
         # Make the request
-        d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase):
         request = MatrixFederationRequest(
             method="GET", destination="testserv:8008", path="foo/bar"
         )
-        d = self.cl._send_request(request, timeout=10000)
+        d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
 
         self.pump()
 
@@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase):
         requiring a trailing slash. We need to retry the request with a
         trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase):
 
         See test_client_requires_trailing_slashes() for context.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase):
         self.failureResultOf(d)
 
     def test_client_sends_body(self):
-        self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+        defer.ensureDeferred(
+            self.cl.post_json(
+                "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+            )
+        )
 
         self.pump()
 
@@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase):
 
     def test_closes_connection(self):
         """Check that the client closes unused HTTP connections"""
-        d = self.cl.get_json("testserv:8008", "foo/bar")
+        d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
 
         self.pump()
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 06575ba0a6..ae60874ec3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -65,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         # Since we use sqlite in memory databases we need to make sure the
         # databases objects are the same.
-        self.worker_hs.get_datastore().db = hs.get_datastore().db
+        self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
 
         self.test_handler = self._build_replication_data_handler()
         self.worker_hs.replication_data_handler = self.test_handler
@@ -198,7 +198,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         self.streamer = self.hs.get_replication_streamer()
 
         store = self.hs.get_datastore()
-        self.database = store.db
+        self.database_pool = store.db_pool
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
@@ -254,7 +254,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         )
 
         store = worker_hs.get_datastore()
-        store.db._db_pool = self.database._db_pool
+        store.db_pool._db_pool = self.database_pool._db_pool
 
         repl_handler = ReplicationCommandHandler(worker_hs)
         client = ClientReplicationStreamProtocol(
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..0b5204654c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         state_handler = self.hs.get_state_handler()
         context = self.get_success(state_handler.compute_event_context(event))
 
-        self.master_store.add_push_actions_to_staging(
-            event.event_id, {user_id: actions for user_id, actions in push_actions}
+        self.get_success(
+            self.master_store.add_push_actions_to_staging(
+                event.event_id, {user_id: actions for user_id, actions in push_actions}
+            )
         )
         return event, context
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 8d4dbf232e..83f9aa291c 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -16,8 +16,6 @@ import logging
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.builder import EventBuilderFactory
 from synapse.rest.admin import register_servlets_for_client_rest_resource
@@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new event.
         """
         mock_client = Mock(spec=["put_json"])
-        mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
 
         self.make_worker_hs(
             "synapse.app.federation_sender",
@@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new events.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         new typing EDUs.
         """
         mock_client1 = Mock(spec=["put_json"])
-        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
@@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         mock_client2 = Mock(spec=["put_json"])
-        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index b1a4decced..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         self.fetches = []
 
-        def get_file(destination, path, output_stream, args=None, max_size=None):
+        async def get_file(destination, path, output_stream, args=None, max_size=None):
             """
             Returns tuple[int,dict,str,int] of file length, response headers,
             absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             d = Deferred()
             d.addCallback(write_to)
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            return await make_deferred_yieldable(d)
 
         client = Mock()
         client.get_file = get_file
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 946f06d151..408c568a27 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,1447 +1,1500 @@
-# -*- coding: utf-8 -*-

-# Copyright 2020 Dirk Klimpel

-#

-# Licensed under the Apache License, Version 2.0 (the "License");

-# you may not use this file except in compliance with the License.

-# You may obtain a copy of the License at

-#

-#     http://www.apache.org/licenses/LICENSE-2.0

-#

-# Unless required by applicable law or agreed to in writing, software

-# distributed under the License is distributed on an "AS IS" BASIS,

-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

-# See the License for the specific language governing permissions and

-# limitations under the License.

-

-import json

-import urllib.parse

-from typing import List, Optional

-

-from mock import Mock

-

-import synapse.rest.admin

-from synapse.api.errors import Codes

-from synapse.rest.client.v1 import directory, events, login, room

-

-from tests import unittest

-

-"""Tests admin REST events for /rooms paths."""

-

-

-class ShutdownRoomTestCase(unittest.HomeserverTestCase):

-    servlets = [

-        synapse.rest.admin.register_servlets_for_client_rest_resource,

-        login.register_servlets,

-        events.register_servlets,

-        room.register_servlets,

-        room.register_deprecated_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.event_creation_handler = hs.get_event_creation_handler()

-        hs.config.user_consent_version = "1"

-

-        consent_uri_builder = Mock()

-        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"

-        self.event_creation_handler._consent_uri_builder = consent_uri_builder

-

-        self.store = hs.get_datastore()

-

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-        self.other_user = self.register_user("user", "pass")

-        self.other_user_token = self.login("user", "pass")

-

-        # Mark the admin user as having consented

-        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))

-

-    def test_shutdown_room_consent(self):

-        """Test that we can shutdown rooms with local users who have not

-        yet accepted the privacy policy. This used to fail when we tried to

-        force part the user from the old room.

-        """

-        self.event_creation_handler._block_events_without_consent_error = None

-

-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)

-

-        # Assert one user in room

-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))

-        self.assertEqual([self.other_user], users_in_room)

-

-        # Enable require consent to send events

-        self.event_creation_handler._block_events_without_consent_error = "Error"

-

-        # Assert that the user is getting consent error

-        self.helper.send(

-            room_id, body="foo", tok=self.other_user_token, expect_code=403

-        )

-

-        # Test that the admin can still send shutdown

-        url = "admin/shutdown_room/" + room_id

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            json.dumps({"new_room_user_id": self.admin_user}),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Assert there is now no longer anyone in the room

-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))

-        self.assertEqual([], users_in_room)

-

-    def test_shutdown_room_block_peek(self):

-        """Test that a world_readable room can no longer be peeked into after

-        it has been shut down.

-        """

-

-        self.event_creation_handler._block_events_without_consent_error = None

-

-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)

-

-        # Enable world readable

-        url = "rooms/%s/state/m.room.history_visibility" % (room_id,)

-        request, channel = self.make_request(

-            "PUT",

-            url.encode("ascii"),

-            json.dumps({"history_visibility": "world_readable"}),

-            access_token=self.other_user_token,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Test that the admin can still send shutdown

-        url = "admin/shutdown_room/" + room_id

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            json.dumps({"new_room_user_id": self.admin_user}),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Assert we can no longer peek into the room

-        self._assert_peek(room_id, expect_code=403)

-

-    def _assert_peek(self, room_id, expect_code):

-        """Assert that the admin user can (or cannot) peek into the room.

-        """

-

-        url = "rooms/%s/initialSync" % (room_id,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok

-        )

-        self.render(request)

-        self.assertEqual(

-            expect_code, int(channel.result["code"]), msg=channel.result["body"]

-        )

-

-        url = "events?timeout=0&room_id=" + room_id

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok

-        )

-        self.render(request)

-        self.assertEqual(

-            expect_code, int(channel.result["code"]), msg=channel.result["body"]

-        )

-

-

-class DeleteRoomTestCase(unittest.HomeserverTestCase):

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        login.register_servlets,

-        events.register_servlets,

-        room.register_servlets,

-        room.register_deprecated_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.event_creation_handler = hs.get_event_creation_handler()

-        hs.config.user_consent_version = "1"

-

-        consent_uri_builder = Mock()

-        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"

-        self.event_creation_handler._consent_uri_builder = consent_uri_builder

-

-        self.store = hs.get_datastore()

-

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-        self.other_user = self.register_user("user", "pass")

-        self.other_user_tok = self.login("user", "pass")

-

-        # Mark the admin user as having consented

-        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))

-

-        self.room_id = self.helper.create_room_as(

-            self.other_user, tok=self.other_user_tok

-        )

-        self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id

-

-    def test_requester_is_no_admin(self):

-        """

-        If the user is not a server admin, an error 403 is returned.

-        """

-

-        request, channel = self.make_request(

-            "POST", self.url, json.dumps({}), access_token=self.other_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

-

-    def test_room_does_not_exist(self):

-        """

-        Check that unknown rooms/server return error 404.

-        """

-        url = "/_synapse/admin/v1/rooms/!unknown:test/delete"

-

-        request, channel = self.make_request(

-            "POST", url, json.dumps({}), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])

-

-    def test_room_is_not_valid(self):

-        """

-        Check that invalid room names, return an error 400.

-        """

-        url = "/_synapse/admin/v1/rooms/invalidroom/delete"

-

-        request, channel = self.make_request(

-            "POST", url, json.dumps({}), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(

-            "invalidroom is not a legal room ID", channel.json_body["error"],

-        )

-

-    def test_new_room_user_does_not_exist(self):

-        """

-        Tests that the user ID must be from local server but it does not have to exist.

-        """

-        body = json.dumps({"new_room_user_id": "@unknown:test"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertIn("new_room_id", channel.json_body)

-        self.assertIn("kicked_users", channel.json_body)

-        self.assertIn("failed_to_kick_users", channel.json_body)

-        self.assertIn("local_aliases", channel.json_body)

-

-    def test_new_room_user_is_not_local(self):

-        """

-        Check that only local users can create new room to move members.

-        """

-        body = json.dumps({"new_room_user_id": "@not:exist.bla"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(

-            "User must be our own: @not:exist.bla", channel.json_body["error"],

-        )

-

-    def test_block_is_not_bool(self):

-        """

-        If parameter `block` is not boolean, return an error

-        """

-        body = json.dumps({"block": "NotBool"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])

-

-    def test_purge_room_and_block(self):

-        """Test to purge a room and block it.

-        Members will not be moved to a new room and will not receive a message.

-        """

-        # Test that room is not purged

-        with self.assertRaises(AssertionError):

-            self._is_purged(self.room_id)

-

-        # Test that room is not blocked

-        self._is_blocked(self.room_id, expect=False)

-

-        # Assert one user in room

-        self._is_member(room_id=self.room_id, user_id=self.other_user)

-

-        body = json.dumps({"block": True})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url.encode("ascii"),

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(None, channel.json_body["new_room_id"])

-        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])

-        self.assertIn("failed_to_kick_users", channel.json_body)

-        self.assertIn("local_aliases", channel.json_body)

-

-        self._is_purged(self.room_id)

-        self._is_blocked(self.room_id, expect=True)

-        self._has_no_members(self.room_id)

-

-    def test_purge_room_and_not_block(self):

-        """Test to purge a room and do not block it.

-        Members will not be moved to a new room and will not receive a message.

-        """

-        # Test that room is not purged

-        with self.assertRaises(AssertionError):

-            self._is_purged(self.room_id)

-

-        # Test that room is not blocked

-        self._is_blocked(self.room_id, expect=False)

-

-        # Assert one user in room

-        self._is_member(room_id=self.room_id, user_id=self.other_user)

-

-        body = json.dumps({"block": False})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url.encode("ascii"),

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(None, channel.json_body["new_room_id"])

-        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])

-        self.assertIn("failed_to_kick_users", channel.json_body)

-        self.assertIn("local_aliases", channel.json_body)

-

-        self._is_purged(self.room_id)

-        self._is_blocked(self.room_id, expect=False)

-        self._has_no_members(self.room_id)

-

-    def test_shutdown_room_consent(self):

-        """Test that we can shutdown rooms with local users who have not

-        yet accepted the privacy policy. This used to fail when we tried to

-        force part the user from the old room.

-        Members will be moved to a new room and will receive a message.

-        """

-        self.event_creation_handler._block_events_without_consent_error = None

-

-        # Assert one user in room

-        users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))

-        self.assertEqual([self.other_user], users_in_room)

-

-        # Enable require consent to send events

-        self.event_creation_handler._block_events_without_consent_error = "Error"

-

-        # Assert that the user is getting consent error

-        self.helper.send(

-            self.room_id, body="foo", tok=self.other_user_tok, expect_code=403

-        )

-

-        # Test that room is not purged

-        with self.assertRaises(AssertionError):

-            self._is_purged(self.room_id)

-

-        # Assert one user in room

-        self._is_member(room_id=self.room_id, user_id=self.other_user)

-

-        # Test that the admin can still send shutdown

-        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            json.dumps({"new_room_user_id": self.admin_user}),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])

-        self.assertIn("new_room_id", channel.json_body)

-        self.assertIn("failed_to_kick_users", channel.json_body)

-        self.assertIn("local_aliases", channel.json_body)

-

-        # Test that member has moved to new room

-        self._is_member(

-            room_id=channel.json_body["new_room_id"], user_id=self.other_user

-        )

-

-        self._is_purged(self.room_id)

-        self._has_no_members(self.room_id)

-

-    def test_shutdown_room_block_peek(self):

-        """Test that a world_readable room can no longer be peeked into after

-        it has been shut down.

-        Members will be moved to a new room and will receive a message.

-        """

-        self.event_creation_handler._block_events_without_consent_error = None

-

-        # Enable world readable

-        url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)

-        request, channel = self.make_request(

-            "PUT",

-            url.encode("ascii"),

-            json.dumps({"history_visibility": "world_readable"}),

-            access_token=self.other_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Test that room is not purged

-        with self.assertRaises(AssertionError):

-            self._is_purged(self.room_id)

-

-        # Assert one user in room

-        self._is_member(room_id=self.room_id, user_id=self.other_user)

-

-        # Test that the admin can still send shutdown

-        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            json.dumps({"new_room_user_id": self.admin_user}),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])

-        self.assertIn("new_room_id", channel.json_body)

-        self.assertIn("failed_to_kick_users", channel.json_body)

-        self.assertIn("local_aliases", channel.json_body)

-

-        # Test that member has moved to new room

-        self._is_member(

-            room_id=channel.json_body["new_room_id"], user_id=self.other_user

-        )

-

-        self._is_purged(self.room_id)

-        self._has_no_members(self.room_id)

-

-        # Assert we can no longer peek into the room

-        self._assert_peek(self.room_id, expect_code=403)

-

-    def _is_blocked(self, room_id, expect=True):

-        """Assert that the room is blocked or not

-        """

-        d = self.store.is_room_blocked(room_id)

-        if expect:

-            self.assertTrue(self.get_success(d))

-        else:

-            self.assertIsNone(self.get_success(d))

-

-    def _has_no_members(self, room_id):

-        """Assert there is now no longer anyone in the room

-        """

-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))

-        self.assertEqual([], users_in_room)

-

-    def _is_member(self, room_id, user_id):

-        """Test that user is member of the room

-        """

-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))

-        self.assertIn(user_id, users_in_room)

-

-    def _is_purged(self, room_id):

-        """Test that the following tables have been purged of all rows related to the room.

-        """

-        for table in (

-            "current_state_events",

-            "event_backward_extremities",

-            "event_forward_extremities",

-            "event_json",

-            "event_push_actions",

-            "event_search",

-            "events",

-            "group_rooms",

-            "public_room_list_stream",

-            "receipts_graph",

-            "receipts_linearized",

-            "room_aliases",

-            "room_depth",

-            "room_memberships",

-            "room_stats_state",

-            "room_stats_current",

-            "room_stats_historical",

-            "room_stats_earliest_token",

-            "rooms",

-            "stream_ordering_to_exterm",

-            "users_in_public_rooms",

-            "users_who_share_private_rooms",

-            "appservice_room_list",

-            "e2e_room_keys",

-            "event_push_summary",

-            "pusher_throttle",

-            "group_summary_rooms",

-            "local_invites",

-            "room_account_data",

-            "room_tags",

-            # "state_groups",  # Current impl leaves orphaned state groups around.

-            "state_groups_state",

-        ):

-            count = self.get_success(

-                self.store.db.simple_select_one_onecol(

-                    table=table,

-                    keyvalues={"room_id": room_id},

-                    retcol="COUNT(*)",

-                    desc="test_purge_room",

-                )

-            )

-

-            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))

-

-    def _assert_peek(self, room_id, expect_code):

-        """Assert that the admin user can (or cannot) peek into the room.

-        """

-

-        url = "rooms/%s/initialSync" % (room_id,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok

-        )

-        self.render(request)

-        self.assertEqual(

-            expect_code, int(channel.result["code"]), msg=channel.result["body"]

-        )

-

-        url = "events?timeout=0&room_id=" + room_id

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok

-        )

-        self.render(request)

-        self.assertEqual(

-            expect_code, int(channel.result["code"]), msg=channel.result["body"]

-        )

-

-

-class PurgeRoomTestCase(unittest.HomeserverTestCase):

-    """Test /purge_room admin API.

-    """

-

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        login.register_servlets,

-        room.register_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.store = hs.get_datastore()

-

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-    def test_purge_room(self):

-        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        # All users have to have left the room.

-        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)

-

-        url = "/_synapse/admin/v1/purge_room"

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            {"room_id": room_id},

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Test that the following tables have been purged of all rows related to the room.

-        for table in (

-            "current_state_events",

-            "event_backward_extremities",

-            "event_forward_extremities",

-            "event_json",

-            "event_push_actions",

-            "event_search",

-            "events",

-            "group_rooms",

-            "public_room_list_stream",

-            "receipts_graph",

-            "receipts_linearized",

-            "room_aliases",

-            "room_depth",

-            "room_memberships",

-            "room_stats_state",

-            "room_stats_current",

-            "room_stats_historical",

-            "room_stats_earliest_token",

-            "rooms",

-            "stream_ordering_to_exterm",

-            "users_in_public_rooms",

-            "users_who_share_private_rooms",

-            "appservice_room_list",

-            "e2e_room_keys",

-            "event_push_summary",

-            "pusher_throttle",

-            "group_summary_rooms",

-            "room_account_data",

-            "room_tags",

-            # "state_groups",  # Current impl leaves orphaned state groups around.

-            "state_groups_state",

-        ):

-            count = self.get_success(

-                self.store.db.simple_select_one_onecol(

-                    table=table,

-                    keyvalues={"room_id": room_id},

-                    retcol="COUNT(*)",

-                    desc="test_purge_room",

-                )

-            )

-

-            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))

-

-

-class RoomTestCase(unittest.HomeserverTestCase):

-    """Test /room admin API.

-    """

-

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        login.register_servlets,

-        room.register_servlets,

-        directory.register_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.store = hs.get_datastore()

-

-        # Create user

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-    def test_list_rooms(self):

-        """Test that we can list rooms"""

-        # Create 3 test rooms

-        total_rooms = 3

-        room_ids = []

-        for x in range(total_rooms):

-            room_id = self.helper.create_room_as(

-                self.admin_user, tok=self.admin_user_tok

-            )

-            room_ids.append(room_id)

-

-        # Request the list of rooms

-        url = "/_synapse/admin/v1/rooms"

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        # Check request completed successfully

-        self.assertEqual(200, int(channel.code), msg=channel.json_body)

-

-        # Check that response json body contains a "rooms" key

-        self.assertTrue(

-            "rooms" in channel.json_body,

-            msg="Response body does not " "contain a 'rooms' key",

-        )

-

-        # Check that 3 rooms were returned

-        self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)

-

-        # Check their room_ids match

-        returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]

-        self.assertEqual(room_ids, returned_room_ids)

-

-        # Check that all fields are available

-        for r in channel.json_body["rooms"]:

-            self.assertIn("name", r)

-            self.assertIn("canonical_alias", r)

-            self.assertIn("joined_members", r)

-            self.assertIn("joined_local_members", r)

-            self.assertIn("version", r)

-            self.assertIn("creator", r)

-            self.assertIn("encryption", r)

-            self.assertIn("federatable", r)

-            self.assertIn("public", r)

-            self.assertIn("join_rules", r)

-            self.assertIn("guest_access", r)

-            self.assertIn("history_visibility", r)

-            self.assertIn("state_events", r)

-

-        # Check that the correct number of total rooms was returned

-        self.assertEqual(channel.json_body["total_rooms"], total_rooms)

-

-        # Check that the offset is correct

-        # Should be 0 as we aren't paginating

-        self.assertEqual(channel.json_body["offset"], 0)

-

-        # Check that the prev_batch parameter is not present

-        self.assertNotIn("prev_batch", channel.json_body)

-

-        # We shouldn't receive a next token here as there's no further rooms to show

-        self.assertNotIn("next_batch", channel.json_body)

-

-    def test_list_rooms_pagination(self):

-        """Test that we can get a full list of rooms through pagination"""

-        # Create 5 test rooms

-        total_rooms = 5

-        room_ids = []

-        for x in range(total_rooms):

-            room_id = self.helper.create_room_as(

-                self.admin_user, tok=self.admin_user_tok

-            )

-            room_ids.append(room_id)

-

-        # Set the name of the rooms so we get a consistent returned ordering

-        for idx, room_id in enumerate(room_ids):

-            self.helper.send_state(

-                room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,

-            )

-

-        # Request the list of rooms

-        returned_room_ids = []

-        start = 0

-        limit = 2

-

-        run_count = 0

-        should_repeat = True

-        while should_repeat:

-            run_count += 1

-

-            url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (

-                start,

-                limit,

-                "name",

-            )

-            request, channel = self.make_request(

-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(

-                200, int(channel.result["code"]), msg=channel.result["body"]

-            )

-

-            self.assertTrue("rooms" in channel.json_body)

-            for r in channel.json_body["rooms"]:

-                returned_room_ids.append(r["room_id"])

-

-            # Check that the correct number of total rooms was returned

-            self.assertEqual(channel.json_body["total_rooms"], total_rooms)

-

-            # Check that the offset is correct

-            # We're only getting 2 rooms each page, so should be 2 * last run_count

-            self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))

-

-            if run_count > 1:

-                # Check the value of prev_batch is correct

-                self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))

-

-            if "next_batch" not in channel.json_body:

-                # We have reached the end of the list

-                should_repeat = False

-            else:

-                # Make another query with an updated start value

-                start = channel.json_body["next_batch"]

-

-        # We should've queried the endpoint 3 times

-        self.assertEqual(

-            run_count,

-            3,

-            msg="Should've queried 3 times for 5 rooms with limit 2 per query",

-        )

-

-        # Check that we received all of the room ids

-        self.assertEqual(room_ids, returned_room_ids)

-

-        url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-    def test_correct_room_attributes(self):

-        """Test the correct attributes for a room are returned"""

-        # Create a test room

-        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        test_alias = "#test:test"

-        test_room_name = "something"

-

-        # Have another user join the room

-        user_2 = self.register_user("user4", "pass")

-        user_tok_2 = self.login("user4", "pass")

-        self.helper.join(room_id, user_2, tok=user_tok_2)

-

-        # Create a new alias to this room

-        url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)

-        request, channel = self.make_request(

-            "PUT",

-            url.encode("ascii"),

-            {"room_id": room_id},

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Set this new alias as the canonical alias for this room

-        self.helper.send_state(

-            room_id,

-            "m.room.aliases",

-            {"aliases": [test_alias]},

-            tok=self.admin_user_tok,

-            state_key="test",

-        )

-        self.helper.send_state(

-            room_id,

-            "m.room.canonical_alias",

-            {"alias": test_alias},

-            tok=self.admin_user_tok,

-        )

-

-        # Set a name for the room

-        self.helper.send_state(

-            room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,

-        )

-

-        # Request the list of rooms

-        url = "/_synapse/admin/v1/rooms"

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Check that rooms were returned

-        self.assertTrue("rooms" in channel.json_body)

-        rooms = channel.json_body["rooms"]

-

-        # Check that only one room was returned

-        self.assertEqual(len(rooms), 1)

-

-        # And that the value of the total_rooms key was correct

-        self.assertEqual(channel.json_body["total_rooms"], 1)

-

-        # Check that the offset is correct

-        # We're not paginating, so should be 0

-        self.assertEqual(channel.json_body["offset"], 0)

-

-        # Check that there is no `prev_batch`

-        self.assertNotIn("prev_batch", channel.json_body)

-

-        # Check that there is no `next_batch`

-        self.assertNotIn("next_batch", channel.json_body)

-

-        # Check that all provided attributes are set

-        r = rooms[0]

-        self.assertEqual(room_id, r["room_id"])

-        self.assertEqual(test_room_name, r["name"])

-        self.assertEqual(test_alias, r["canonical_alias"])

-

-    def test_room_list_sort_order(self):

-        """Test room list sort ordering. alphabetical name versus number of members,

-        reversing the order, etc.

-        """

-

-        def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):

-            # Create a new alias to this room

-            url = "/_matrix/client/r0/directory/room/%s" % (

-                urllib.parse.quote(test_alias),

-            )

-            request, channel = self.make_request(

-                "PUT",

-                url.encode("ascii"),

-                {"room_id": room_id},

-                access_token=admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(

-                200, int(channel.result["code"]), msg=channel.result["body"]

-            )

-

-            # Set this new alias as the canonical alias for this room

-            self.helper.send_state(

-                room_id,

-                "m.room.aliases",

-                {"aliases": [test_alias]},

-                tok=admin_user_tok,

-                state_key="test",

-            )

-            self.helper.send_state(

-                room_id,

-                "m.room.canonical_alias",

-                {"alias": test_alias},

-                tok=admin_user_tok,

-            )

-

-        def _order_test(

-            order_type: str, expected_room_list: List[str], reverse: bool = False,

-        ):

-            """Request the list of rooms in a certain order. Assert that order is what

-            we expect

-

-            Args:

-                order_type: The type of ordering to give the server

-                expected_room_list: The list of room_ids in the order we expect to get

-                    back from the server

-            """

-            # Request the list of rooms in the given order

-            url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)

-            if reverse:

-                url += "&dir=b"

-            request, channel = self.make_request(

-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(200, channel.code, msg=channel.json_body)

-

-            # Check that rooms were returned

-            self.assertTrue("rooms" in channel.json_body)

-            rooms = channel.json_body["rooms"]

-

-            # Check for the correct total_rooms value

-            self.assertEqual(channel.json_body["total_rooms"], 3)

-

-            # Check that the offset is correct

-            # We're not paginating, so should be 0

-            self.assertEqual(channel.json_body["offset"], 0)

-

-            # Check that there is no `prev_batch`

-            self.assertNotIn("prev_batch", channel.json_body)

-

-            # Check that there is no `next_batch`

-            self.assertNotIn("next_batch", channel.json_body)

-

-            # Check that rooms were returned in alphabetical order

-            returned_order = [r["room_id"] for r in rooms]

-            self.assertListEqual(expected_room_list, returned_order)  # order is checked

-

-        # Create 3 test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C

-        self.helper.send_state(

-            room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,

-        )

-

-        # Set room canonical room aliases

-        _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)

-        _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)

-        _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)

-

-        # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3

-        user_1 = self.register_user("bob1", "pass")

-        user_1_tok = self.login("bob1", "pass")

-        self.helper.join(room_id_2, user_1, tok=user_1_tok)

-

-        user_2 = self.register_user("bob2", "pass")

-        user_2_tok = self.login("bob2", "pass")

-        self.helper.join(room_id_3, user_2, tok=user_2_tok)

-

-        user_3 = self.register_user("bob3", "pass")

-        user_3_tok = self.login("bob3", "pass")

-        self.helper.join(room_id_3, user_3, tok=user_3_tok)

-

-        # Test different sort orders, with forward and reverse directions

-        _order_test("name", [room_id_1, room_id_2, room_id_3])

-        _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)

-

-        _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])

-        _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)

-

-        _order_test("joined_members", [room_id_3, room_id_2, room_id_1])

-        _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])

-        _order_test(

-            "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True

-        )

-

-        _order_test("version", [room_id_1, room_id_2, room_id_3])

-        _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("creator", [room_id_1, room_id_2, room_id_3])

-        _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("encryption", [room_id_1, room_id_2, room_id_3])

-        _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("federatable", [room_id_1, room_id_2, room_id_3])

-        _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("public", [room_id_1, room_id_2, room_id_3])

-        # Different sort order of SQlite and PostreSQL

-        # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)

-

-        _order_test("join_rules", [room_id_1, room_id_2, room_id_3])

-        _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("guest_access", [room_id_1, room_id_2, room_id_3])

-        _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])

-        _order_test(

-            "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True

-        )

-

-        _order_test("state_events", [room_id_3, room_id_2, room_id_1])

-        _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-    def test_search_term(self):

-        """Test that searching for a room works correctly"""

-        # Create two test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        room_name_1 = "something"

-        room_name_2 = "else"

-

-        # Set the name for each room

-        self.helper.send_state(

-            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,

-        )

-

-        def _search_test(

-            expected_room_id: Optional[str],

-            search_term: str,

-            expected_http_code: int = 200,

-        ):

-            """Search for a room and check that the returned room's id is a match

-

-            Args:

-                expected_room_id: The room_id expected to be returned by the API. Set

-                    to None to expect zero results for the search

-                search_term: The term to search for room names with

-                expected_http_code: The expected http code for the request

-            """

-            url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)

-            request, channel = self.make_request(

-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)

-

-            if expected_http_code != 200:

-                return

-

-            # Check that rooms were returned

-            self.assertTrue("rooms" in channel.json_body)

-            rooms = channel.json_body["rooms"]

-

-            # Check that the expected number of rooms were returned

-            expected_room_count = 1 if expected_room_id else 0

-            self.assertEqual(len(rooms), expected_room_count)

-            self.assertEqual(channel.json_body["total_rooms"], expected_room_count)

-

-            # Check that the offset is correct

-            # We're not paginating, so should be 0

-            self.assertEqual(channel.json_body["offset"], 0)

-

-            # Check that there is no `prev_batch`

-            self.assertNotIn("prev_batch", channel.json_body)

-

-            # Check that there is no `next_batch`

-            self.assertNotIn("next_batch", channel.json_body)

-

-            if expected_room_id:

-                # Check that the first returned room id is correct

-                r = rooms[0]

-                self.assertEqual(expected_room_id, r["room_id"])

-

-        # Perform search tests

-        _search_test(room_id_1, "something")

-        _search_test(room_id_1, "thing")

-

-        _search_test(room_id_2, "else")

-        _search_test(room_id_2, "se")

-

-        _search_test(None, "foo")

-        _search_test(None, "bar")

-        _search_test(None, "", expected_http_code=400)

-

-    def test_single_room(self):

-        """Test that a single room can be requested correctly"""

-        # Create two test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        room_name_1 = "something"

-        room_name_2 = "else"

-

-        # Set the name for each room

-        self.helper.send_state(

-            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,

-        )

-

-        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, channel.code, msg=channel.json_body)

-

-        self.assertIn("room_id", channel.json_body)

-        self.assertIn("name", channel.json_body)

-        self.assertIn("canonical_alias", channel.json_body)

-        self.assertIn("joined_members", channel.json_body)

-        self.assertIn("joined_local_members", channel.json_body)

-        self.assertIn("version", channel.json_body)

-        self.assertIn("creator", channel.json_body)

-        self.assertIn("encryption", channel.json_body)

-        self.assertIn("federatable", channel.json_body)

-        self.assertIn("public", channel.json_body)

-        self.assertIn("join_rules", channel.json_body)

-        self.assertIn("guest_access", channel.json_body)

-        self.assertIn("history_visibility", channel.json_body)

-        self.assertIn("state_events", channel.json_body)

-

-        self.assertEqual(room_id_1, channel.json_body["room_id"])

-

-    def test_room_members(self):

-        """Test that room members can be requested correctly"""

-        # Create two test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        # Have another user join the room

-        user_1 = self.register_user("foo", "pass")

-        user_tok_1 = self.login("foo", "pass")

-        self.helper.join(room_id_1, user_1, tok=user_tok_1)

-

-        # Have another user join the room

-        user_2 = self.register_user("bar", "pass")

-        user_tok_2 = self.login("bar", "pass")

-        self.helper.join(room_id_1, user_2, tok=user_tok_2)

-        self.helper.join(room_id_2, user_2, tok=user_tok_2)

-

-        # Have another user join the room

-        user_3 = self.register_user("foobar", "pass")

-        user_tok_3 = self.login("foobar", "pass")

-        self.helper.join(room_id_2, user_3, tok=user_tok_3)

-

-        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, channel.code, msg=channel.json_body)

-

-        self.assertCountEqual(

-            ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]

-        )

-        self.assertEqual(channel.json_body["total"], 3)

-

-        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, channel.code, msg=channel.json_body)

-

-        self.assertCountEqual(

-            ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]

-        )

-        self.assertEqual(channel.json_body["total"], 3)

-

-

-class JoinAliasRoomTestCase(unittest.HomeserverTestCase):

-

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        room.register_servlets,

-        login.register_servlets,

-    ]

-

-    def prepare(self, reactor, clock, homeserver):

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-        self.creator = self.register_user("creator", "test")

-        self.creator_tok = self.login("creator", "test")

-

-        self.second_user_id = self.register_user("second", "test")

-        self.second_tok = self.login("second", "test")

-

-        self.public_room_id = self.helper.create_room_as(

-            self.creator, tok=self.creator_tok, is_public=True

-        )

-        self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)

-

-    def test_requester_is_no_admin(self):

-        """

-        If the user is not a server admin, an error 403 is returned.

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.second_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

-

-    def test_invalid_parameter(self):

-        """

-        If a parameter is missing, return an error

-        """

-        body = json.dumps({"unknown_parameter": "@unknown:test"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])

-

-    def test_local_user_does_not_exist(self):

-        """

-        Tests that a lookup for a user that does not exist returns a 404

-        """

-        body = json.dumps({"user_id": "@unknown:test"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])

-

-    def test_remote_user(self):

-        """

-        Check that only local user can join rooms.

-        """

-        body = json.dumps({"user_id": "@not:exist.bla"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(

-            "This endpoint can only be used with local users",

-            channel.json_body["error"],

-        )

-

-    def test_room_does_not_exist(self):

-        """

-        Check that unknown rooms/server return error 404.

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-        url = "/_synapse/admin/v1/join/!unknown:test"

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual("No known servers", channel.json_body["error"])

-

-    def test_room_is_not_valid(self):

-        """

-        Check that invalid room names, return an error 400.

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-        url = "/_synapse/admin/v1/join/invalidroom"

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(

-            "invalidroom was not legal room ID or room alias",

-            channel.json_body["error"],

-        )

-

-    def test_join_public_room(self):

-        """

-        Test joining a local user to a public room with "JoinRules.PUBLIC"

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(self.public_room_id, channel.json_body["room_id"])

-

-        # Validate if user is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])

-

-    def test_join_private_room_if_not_member(self):

-        """

-        Test joining a local user to a private room with "JoinRules.INVITE"

-        when server admin is not member of this room.

-        """

-        private_room_id = self.helper.create_room_as(

-            self.creator, tok=self.creator_tok, is_public=False

-        )

-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

-

-    def test_join_private_room_if_member(self):

-        """

-        Test joining a local user to a private room with "JoinRules.INVITE",

-        when server admin is member of this room.

-        """

-        private_room_id = self.helper.create_room_as(

-            self.creator, tok=self.creator_tok, is_public=False

-        )

-        self.helper.invite(

-            room=private_room_id,

-            src=self.creator,

-            targ=self.admin_user,

-            tok=self.creator_tok,

-        )

-        self.helper.join(

-            room=private_room_id, user=self.admin_user, tok=self.admin_user_tok

-        )

-

-        # Validate if server admin is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])

-

-        # Join user to room.

-

-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["room_id"])

-

-        # Validate if user is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])

-

-    def test_join_private_room_if_owner(self):

-        """

-        Test joining a local user to a private room with "JoinRules.INVITE",

-        when server admin is owner of this room.

-        """

-        private_room_id = self.helper.create_room_as(

-            self.admin_user, tok=self.admin_user_tok, is_public=False

-        )

-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["room_id"])

-

-        # Validate if user is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])

+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_token = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shutdown rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Enable require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user is getting consent error
+        self.helper.send(
+            room_id, body="foo", tok=self.other_user_token, expect_code=403
+        )
+
+        # Test that the admin can still send shutdown
+        url = "admin/shutdown_room/" + room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Assert there is now no longer anyone in the room
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        """
+
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+        # Enable world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            json.dumps({"history_visibility": "world_readable"}),
+            access_token=self.other_user_token,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that the admin can still send shutdown
+        url = "admin/shutdown_room/" + room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(room_id, expect_code=403)
+
+    def _assert_peek(self, room_id, expect_code):
+        """Assert that the admin user can (or cannot) peek into the room.
+        """
+
+        url = "rooms/%s/initialSync" % (room_id,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+        url = "events?timeout=0&room_id=" + room_id
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        request, channel = self.make_request(
+            "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_room_does_not_exist(self):
+        """
+        Check that unknown rooms/server return error 404.
+        """
+        url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_room_is_not_valid(self):
+        """
+        Check that invalid room names, return an error 400.
+        """
+        url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "invalidroom is not a legal room ID", channel.json_body["error"],
+        )
+
+    def test_new_room_user_does_not_exist(self):
+        """
+        Tests that the user ID must be from local server but it does not have to exist.
+        """
+        body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("kicked_users", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+    def test_new_room_user_is_not_local(self):
+        """
+        Check that only local users can create new room to move members.
+        """
+        body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "User must be our own: @not:exist.bla", channel.json_body["error"],
+        )
+
+    def test_block_is_not_bool(self):
+        """
+        If parameter `block` is not boolean, return an error
+        """
+        body = json.dumps({"block": "NotBool"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_is_not_bool(self):
+        """
+        If parameter `purge` is not boolean, return an error
+        """
+        body = json.dumps({"purge": "NotBool"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_room_and_block(self):
+        """Test to purge a room and block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": True, "purge": True})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_purge_room_and_not_block(self):
+        """Test to purge a room and do not block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": False, "purge": True})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_block_room_and_not_purge(self):
+        """Test to block a room without purging it.
+        Members will not be moved to a new room and will not receive a message.
+        The room will not be purged.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": False, "purge": False})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shutdown rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Enable require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user is getting consent error
+        self.helper.send(
+            self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+        )
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["new_room_id"], user_id=self.other_user
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Enable world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            json.dumps({"history_visibility": "world_readable"}),
+            access_token=self.other_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["new_room_id"], user_id=self.other_user
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(self.room_id, expect_code=403)
+
+    def _is_blocked(self, room_id, expect=True):
+        """Assert that the room is blocked or not
+        """
+        d = self.store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _has_no_members(self, room_id):
+        """Assert there is now no longer anyone in the room
+        """
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def _is_member(self, room_id, user_id):
+        """Test that user is member of the room
+        """
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertIn(user_id, users_in_room)
+
+    def _is_purged(self, room_id):
+        """Test that the following tables have been purged of all rows related to the room.
+        """
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+            # "state_groups",  # Current impl leaves orphaned state groups around.
+            "state_groups_state",
+        ):
+            count = self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+    def _assert_peek(self, room_id, expect_code):
+        """Assert that the admin user can (or cannot) peek into the room.
+        """
+
+        url = "rooms/%s/initialSync" % (room_id,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+        url = "events?timeout=0&room_id=" + room_id
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+    """Test /purge_room admin API.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+    def test_purge_room(self):
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # All users have to have left the room.
+        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/purge_room"
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            {"room_id": room_id},
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that the following tables have been purged of all rows related to the room.
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "room_account_data",
+            "room_tags",
+            # "state_groups",  # Current impl leaves orphaned state groups around.
+            "state_groups_state",
+        ):
+            count = self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+    """Test /room admin API.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        directory.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        # Create user
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+    def test_list_rooms(self):
+        """Test that we can list rooms"""
+        # Create 3 test rooms
+        total_rooms = 3
+        room_ids = []
+        for x in range(total_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            room_ids.append(room_id)
+
+        # Request the list of rooms
+        url = "/_synapse/admin/v1/rooms"
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        # Check request completed successfully
+        self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+        # Check that response json body contains a "rooms" key
+        self.assertTrue(
+            "rooms" in channel.json_body,
+            msg="Response body does not " "contain a 'rooms' key",
+        )
+
+        # Check that 3 rooms were returned
+        self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+        # Check their room_ids match
+        returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+        self.assertEqual(room_ids, returned_room_ids)
+
+        # Check that all fields are available
+        for r in channel.json_body["rooms"]:
+            self.assertIn("name", r)
+            self.assertIn("canonical_alias", r)
+            self.assertIn("joined_members", r)
+            self.assertIn("joined_local_members", r)
+            self.assertIn("version", r)
+            self.assertIn("creator", r)
+            self.assertIn("encryption", r)
+            self.assertIn("federatable", r)
+            self.assertIn("public", r)
+            self.assertIn("join_rules", r)
+            self.assertIn("guest_access", r)
+            self.assertIn("history_visibility", r)
+            self.assertIn("state_events", r)
+
+        # Check that the correct number of total rooms was returned
+        self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+        # Check that the offset is correct
+        # Should be 0 as we aren't paginating
+        self.assertEqual(channel.json_body["offset"], 0)
+
+        # Check that the prev_batch parameter is not present
+        self.assertNotIn("prev_batch", channel.json_body)
+
+        # We shouldn't receive a next token here as there's no further rooms to show
+        self.assertNotIn("next_batch", channel.json_body)
+
+    def test_list_rooms_pagination(self):
+        """Test that we can get a full list of rooms through pagination"""
+        # Create 5 test rooms
+        total_rooms = 5
+        room_ids = []
+        for x in range(total_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            room_ids.append(room_id)
+
+        # Set the name of the rooms so we get a consistent returned ordering
+        for idx, room_id in enumerate(room_ids):
+            self.helper.send_state(
+                room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+            )
+
+        # Request the list of rooms
+        returned_room_ids = []
+        start = 0
+        limit = 2
+
+        run_count = 0
+        should_repeat = True
+        while should_repeat:
+            run_count += 1
+
+            url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+                start,
+                limit,
+                "name",
+            )
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(
+                200, int(channel.result["code"]), msg=channel.result["body"]
+            )
+
+            self.assertTrue("rooms" in channel.json_body)
+            for r in channel.json_body["rooms"]:
+                returned_room_ids.append(r["room_id"])
+
+            # Check that the correct number of total rooms was returned
+            self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+            # Check that the offset is correct
+            # We're only getting 2 rooms each page, so should be 2 * last run_count
+            self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+            if run_count > 1:
+                # Check the value of prev_batch is correct
+                self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+            if "next_batch" not in channel.json_body:
+                # We have reached the end of the list
+                should_repeat = False
+            else:
+                # Make another query with an updated start value
+                start = channel.json_body["next_batch"]
+
+        # We should've queried the endpoint 3 times
+        self.assertEqual(
+            run_count,
+            3,
+            msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+        )
+
+        # Check that we received all of the room ids
+        self.assertEqual(room_ids, returned_room_ids)
+
+        url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    def test_correct_room_attributes(self):
+        """Test the correct attributes for a room are returned"""
+        # Create a test room
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        test_alias = "#test:test"
+        test_room_name = "something"
+
+        # Have another user join the room
+        user_2 = self.register_user("user4", "pass")
+        user_tok_2 = self.login("user4", "pass")
+        self.helper.join(room_id, user_2, tok=user_tok_2)
+
+        # Create a new alias to this room
+        url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            {"room_id": room_id},
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Set this new alias as the canonical alias for this room
+        self.helper.send_state(
+            room_id,
+            "m.room.aliases",
+            {"aliases": [test_alias]},
+            tok=self.admin_user_tok,
+            state_key="test",
+        )
+        self.helper.send_state(
+            room_id,
+            "m.room.canonical_alias",
+            {"alias": test_alias},
+            tok=self.admin_user_tok,
+        )
+
+        # Set a name for the room
+        self.helper.send_state(
+            room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+        )
+
+        # Request the list of rooms
+        url = "/_synapse/admin/v1/rooms"
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Check that rooms were returned
+        self.assertTrue("rooms" in channel.json_body)
+        rooms = channel.json_body["rooms"]
+
+        # Check that only one room was returned
+        self.assertEqual(len(rooms), 1)
+
+        # And that the value of the total_rooms key was correct
+        self.assertEqual(channel.json_body["total_rooms"], 1)
+
+        # Check that the offset is correct
+        # We're not paginating, so should be 0
+        self.assertEqual(channel.json_body["offset"], 0)
+
+        # Check that there is no `prev_batch`
+        self.assertNotIn("prev_batch", channel.json_body)
+
+        # Check that there is no `next_batch`
+        self.assertNotIn("next_batch", channel.json_body)
+
+        # Check that all provided attributes are set
+        r = rooms[0]
+        self.assertEqual(room_id, r["room_id"])
+        self.assertEqual(test_room_name, r["name"])
+        self.assertEqual(test_alias, r["canonical_alias"])
+
+    def test_room_list_sort_order(self):
+        """Test room list sort ordering. alphabetical name versus number of members,
+        reversing the order, etc.
+        """
+
+        def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+            # Create a new alias to this room
+            url = "/_matrix/client/r0/directory/room/%s" % (
+                urllib.parse.quote(test_alias),
+            )
+            request, channel = self.make_request(
+                "PUT",
+                url.encode("ascii"),
+                {"room_id": room_id},
+                access_token=admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(
+                200, int(channel.result["code"]), msg=channel.result["body"]
+            )
+
+            # Set this new alias as the canonical alias for this room
+            self.helper.send_state(
+                room_id,
+                "m.room.aliases",
+                {"aliases": [test_alias]},
+                tok=admin_user_tok,
+                state_key="test",
+            )
+            self.helper.send_state(
+                room_id,
+                "m.room.canonical_alias",
+                {"alias": test_alias},
+                tok=admin_user_tok,
+            )
+
+        def _order_test(
+            order_type: str, expected_room_list: List[str], reverse: bool = False,
+        ):
+            """Request the list of rooms in a certain order. Assert that order is what
+            we expect
+
+            Args:
+                order_type: The type of ordering to give the server
+                expected_room_list: The list of room_ids in the order we expect to get
+                    back from the server
+            """
+            # Request the list of rooms in the given order
+            url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+            if reverse:
+                url += "&dir=b"
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(200, channel.code, msg=channel.json_body)
+
+            # Check that rooms were returned
+            self.assertTrue("rooms" in channel.json_body)
+            rooms = channel.json_body["rooms"]
+
+            # Check for the correct total_rooms value
+            self.assertEqual(channel.json_body["total_rooms"], 3)
+
+            # Check that the offset is correct
+            # We're not paginating, so should be 0
+            self.assertEqual(channel.json_body["offset"], 0)
+
+            # Check that there is no `prev_batch`
+            self.assertNotIn("prev_batch", channel.json_body)
+
+            # Check that there is no `next_batch`
+            self.assertNotIn("next_batch", channel.json_body)
+
+            # Check that rooms were returned in alphabetical order
+            returned_order = [r["room_id"] for r in rooms]
+            self.assertListEqual(expected_room_list, returned_order)  # order is checked
+
+        # Create 3 test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+        )
+
+        # Set room canonical room aliases
+        _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+        _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+        _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+        # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+        user_1 = self.register_user("bob1", "pass")
+        user_1_tok = self.login("bob1", "pass")
+        self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+        user_2 = self.register_user("bob2", "pass")
+        user_2_tok = self.login("bob2", "pass")
+        self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+        user_3 = self.register_user("bob3", "pass")
+        user_3_tok = self.login("bob3", "pass")
+        self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+        # Test different sort orders, with forward and reverse directions
+        _order_test("name", [room_id_1, room_id_2, room_id_3])
+        _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+        _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+        _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+        _order_test(
+            "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+        )
+
+        _order_test("version", [room_id_1, room_id_2, room_id_3])
+        _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("creator", [room_id_1, room_id_2, room_id_3])
+        _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+        _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+        _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("public", [room_id_1, room_id_2, room_id_3])
+        # Different sort order of SQlite and PostreSQL
+        # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+        _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+        _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+        _order_test(
+            "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+        )
+
+        _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+        _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+    def test_search_term(self):
+        """Test that searching for a room works correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        room_name_1 = "something"
+        room_name_2 = "else"
+
+        # Set the name for each room
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+        )
+
+        def _search_test(
+            expected_room_id: Optional[str],
+            search_term: str,
+            expected_http_code: int = 200,
+        ):
+            """Search for a room and check that the returned room's id is a match
+
+            Args:
+                expected_room_id: The room_id expected to be returned by the API. Set
+                    to None to expect zero results for the search
+                search_term: The term to search for room names with
+                expected_http_code: The expected http code for the request
+            """
+            url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+            if expected_http_code != 200:
+                return
+
+            # Check that rooms were returned
+            self.assertTrue("rooms" in channel.json_body)
+            rooms = channel.json_body["rooms"]
+
+            # Check that the expected number of rooms were returned
+            expected_room_count = 1 if expected_room_id else 0
+            self.assertEqual(len(rooms), expected_room_count)
+            self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+            # Check that the offset is correct
+            # We're not paginating, so should be 0
+            self.assertEqual(channel.json_body["offset"], 0)
+
+            # Check that there is no `prev_batch`
+            self.assertNotIn("prev_batch", channel.json_body)
+
+            # Check that there is no `next_batch`
+            self.assertNotIn("next_batch", channel.json_body)
+
+            if expected_room_id:
+                # Check that the first returned room id is correct
+                r = rooms[0]
+                self.assertEqual(expected_room_id, r["room_id"])
+
+        # Perform search tests
+        _search_test(room_id_1, "something")
+        _search_test(room_id_1, "thing")
+
+        _search_test(room_id_2, "else")
+        _search_test(room_id_2, "se")
+
+        _search_test(None, "foo")
+        _search_test(None, "bar")
+        _search_test(None, "", expected_http_code=400)
+
+    def test_single_room(self):
+        """Test that a single room can be requested correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        room_name_1 = "something"
+        room_name_2 = "else"
+
+        # Set the name for each room
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+        )
+
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertIn("room_id", channel.json_body)
+        self.assertIn("name", channel.json_body)
+        self.assertIn("canonical_alias", channel.json_body)
+        self.assertIn("joined_members", channel.json_body)
+        self.assertIn("joined_local_members", channel.json_body)
+        self.assertIn("version", channel.json_body)
+        self.assertIn("creator", channel.json_body)
+        self.assertIn("encryption", channel.json_body)
+        self.assertIn("federatable", channel.json_body)
+        self.assertIn("public", channel.json_body)
+        self.assertIn("join_rules", channel.json_body)
+        self.assertIn("guest_access", channel.json_body)
+        self.assertIn("history_visibility", channel.json_body)
+        self.assertIn("state_events", channel.json_body)
+
+        self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+    def test_room_members(self):
+        """Test that room members can be requested correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # Have another user join the room
+        user_1 = self.register_user("foo", "pass")
+        user_tok_1 = self.login("foo", "pass")
+        self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+        # Have another user join the room
+        user_2 = self.register_user("bar", "pass")
+        user_tok_2 = self.login("bar", "pass")
+        self.helper.join(room_id_1, user_2, tok=user_tok_2)
+        self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+        # Have another user join the room
+        user_3 = self.register_user("foobar", "pass")
+        user_tok_3 = self.login("foobar", "pass")
+        self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertCountEqual(
+            ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+        )
+        self.assertEqual(channel.json_body["total"], 3)
+
+        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertCountEqual(
+            ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+        )
+        self.assertEqual(channel.json_body["total"], 3)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.creator = self.register_user("creator", "test")
+        self.creator_tok = self.login("creator", "test")
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+
+        self.public_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+        self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.second_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self):
+        """
+        If a parameter is missing, return an error
+        """
+        body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+    def test_local_user_does_not_exist(self):
+        """
+        Tests that a lookup for a user that does not exist returns a 404
+        """
+        body = json.dumps({"user_id": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_remote_user(self):
+        """
+        Check that only local user can join rooms.
+        """
+        body = json.dumps({"user_id": "@not:exist.bla"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "This endpoint can only be used with local users",
+            channel.json_body["error"],
+        )
+
+    def test_room_does_not_exist(self):
+        """
+        Check that unknown rooms/server return error 404.
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+        url = "/_synapse/admin/v1/join/!unknown:test"
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("No known servers", channel.json_body["error"])
+
+    def test_room_is_not_valid(self):
+        """
+        Check that invalid room names, return an error 400.
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+        url = "/_synapse/admin/v1/join/invalidroom"
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "invalidroom was not legal room ID or room alias",
+            channel.json_body["error"],
+        )
+
+    def test_join_public_room(self):
+        """
+        Test joining a local user to a public room with "JoinRules.PUBLIC"
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+        # Validate if user is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+    def test_join_private_room_if_not_member(self):
+        """
+        Test joining a local user to a private room with "JoinRules.INVITE"
+        when server admin is not member of this room.
+        """
+        private_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False
+        )
+        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_join_private_room_if_member(self):
+        """
+        Test joining a local user to a private room with "JoinRules.INVITE",
+        when server admin is member of this room.
+        """
+        private_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False
+        )
+        self.helper.invite(
+            room=private_room_id,
+            src=self.creator,
+            targ=self.admin_user,
+            tok=self.creator_tok,
+        )
+        self.helper.join(
+            room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+        )
+
+        # Validate if server admin is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+        # Join user to room.
+
+        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+        # Validate if user is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+    def test_join_private_room_if_owner(self):
+        """
+        Test joining a local user to a private room with "JoinRules.INVITE",
+        when server admin is owner of this room.
+        """
+        private_room_id = self.helper.create_room_as(
+            self.admin_user, tok=self.admin_user_tok, is_public=False
+        )
+        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+        # Validate if user is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..51941f99f9 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -88,7 +88,28 @@ class RestHelper(object):
             expect_code=expect_code,
         )
 
-    def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+    def change_membership(
+        self,
+        room: str,
+        src: str,
+        targ: str,
+        membership: str,
+        extra_data: dict = {},
+        tok: Optional[str] = None,
+        expect_code: int = 200,
+    ) -> None:
+        """
+        Send a membership state event into a room.
+
+        Args:
+            room: The ID of the room to send to
+            src: The mxid of the event sender
+            targ: The mxid of the event's target. The state key
+            membership: The type of membership event
+            extra_data: Extra information to include in the content of the event
+            tok: The user access token to use
+            expect_code: The expected HTTP response code
+        """
         temp_id = self.auth_user_id
         self.auth_user_id = src
 
@@ -97,6 +118,7 @@ class RestHelper(object):
             path = path + "?access_token=%s" % tok
 
         data = {"membership": membership}
+        data.update(extra_data)
 
         request, channel = make_request(
             self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
@@ -143,6 +165,26 @@ class RestHelper(object):
 
         return channel.json_body
 
+    def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
+        if txn_id is None:
+            txn_id = "m%s" % (str(time.time()))
+
+        path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
+        if tok:
+            path = path + "?access_token=%s" % tok
+
+        request, channel = make_request(
+            self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
+        )
+        render(request, self.resource, self.hs.get_reactor())
+
+        assert int(channel.result["code"]) == expect_code, (
+            "Expected: %d, got: %d, resp: %r"
+            % (expect_code, int(channel.result["code"]), channel.result["body"])
+        )
+
+        return channel.json_body
+
     def _read_write_state(
         self,
         room_id: str,
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
 import json
 
 import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
 
 from tests import unittest
 from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             "GET", sync_url % (access_token, next_batch)
         )
         self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        read_marker.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.url = "/sync?since=%s"
+        self.next_batch = "s0"
+
+        # Register the first user (used to check the unread counts).
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        # Create the room we'll check unread counts for.
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # Register the second user (used to send events to the room).
+        self.user2 = self.register_user("kermit2", "monkey")
+        self.tok2 = self.login("kermit2", "monkey")
+
+        # Change the power levels of the room so that the second user can send state
+        # events.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.PowerLevels,
+            {
+                "users": {self.user_id: 100, self.user2: 100},
+                "users_default": 0,
+                "events": {
+                    "m.room.name": 50,
+                    "m.room.power_levels": 100,
+                    "m.room.history_visibility": 100,
+                    "m.room.canonical_alias": 50,
+                    "m.room.avatar": 50,
+                    "m.room.tombstone": 100,
+                    "m.room.server_acl": 100,
+                    "m.room.encryption": 100,
+                },
+                "events_default": 0,
+                "state_default": 50,
+                "ban": 50,
+                "kick": 50,
+                "redact": 50,
+                "invite": 0,
+            },
+            tok=self.tok,
+        )
+
+    def test_unread_counts(self):
+        """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+        # Check that our own messages don't increase the unread count.
+        self.helper.send(self.room_id, "hello", tok=self.tok)
+        self._check_unread_count(0)
+
+        # Join the new user and check that this doesn't increase the unread count.
+        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+        self._check_unread_count(0)
+
+        # Check that the new user sending a message increases our unread count.
+        res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+        self._check_unread_count(1)
+
+        # Send a read receipt to tell the server we've read the latest event.
+        body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+        request, channel = self.make_request(
+            "POST",
+            "/rooms/%s/read_markers" % self.room_id,
+            body,
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the unread counter is back to 0.
+        self._check_unread_count(0)
+
+        # Check that room name changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+        )
+        self._check_unread_count(1)
+
+        # Check that room topic changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+        )
+        self._check_unread_count(2)
+
+        # Check that encrypted messages increase the unread counter.
+        self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+        self._check_unread_count(3)
+
+        # Check that custom events with a body increase the unread counter.
+        self.helper.send_event(
+            self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that edits don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": "hello",
+                "msgtype": "m.text",
+                "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+            },
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that notices don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"body": "hello", "msgtype": "m.notice"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that tombstone events changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.Tombstone,
+            {"replacement_room": "!someroom:test"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: True):
+        """Syncs and compares the unread count with the expected value."""
+
+        request, channel = self.make_request(
+            "GET", self.url % self.next_batch, access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        self.assertEqual(
+            room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+        )
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect an outgoing GET request for the given key
         """
 
-        def get_json(destination, path, ignore_backoff=False, **kwargs):
+        async def get_json(destination, path, ignore_backoff=False, **kwargs):
             self.assertTrue(ignore_backoff)
             self.assertEqual(destination, server_name)
             key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
         # will be forwarded to hs1
-        def post_json(destination, path, data):
+        async def post_json(destination, path, data):
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
                 path, "/_matrix/key/v2/query",
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 2826211f32..74765a582b 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,8 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import json
 import os
+import re
+
+from mock import patch
 
 import attr
 
@@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.reactor.nameResolver = Resolver()
 
     def test_cache_returns_correct_type(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         request, channel = self.make_request(
             "GET", "url_preview?url=http://matrix.org", shorthand=False
@@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
 
     def test_non_ascii_preview_httpequiv(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
             b"<html><head>"
@@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
     def test_non_ascii_preview_content_type(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
             b"<html><head>"
@@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
     def test_overlong_title(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
             b"<html><head>"
@@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         IP addresses can be previewed directly.
         """
-        self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
         request, channel = self.make_request(
             "GET", "url_preview?url=http://example.com", shorthand=False
@@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # Hardcode the URL resolving to the IP we want.
         self.lookups["example.com"] = [
             (IPv4Address, "1.1.1.2"),
-            (IPv4Address, "8.8.8.8"),
+            (IPv4Address, "10.1.2.3"),
         ]
 
         request, channel = self.make_request(
@@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         Accept-Language header is sent to the remote server
         """
-        self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
         # Build and make a request to the server
         request, channel = self.make_request(
@@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             ),
             server.data,
         )
+
+    def test_oembed_photo(self):
+        """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+        # Route the HTTP version to an HTTP endpoint so that the tests work.
+        with patch.dict(
+            "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+            {
+                re.compile(
+                    r"http://twitter\.com/.+/status/.+"
+                ): "http://publish.twitter.com/oembed",
+            },
+            clear=True,
+        ):
+
+            self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+            self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+            result = {
+                "version": "1.0",
+                "type": "photo",
+                "url": "http://cdn.twitter.com/matrixdotorg",
+            }
+            oembed_content = json.dumps(result).encode("utf-8")
+
+            end_content = (
+                b"<html><head>"
+                b"<title>Some Title</title>"
+                b'<meta property="og:description" content="hi" />'
+                b"</head></html>"
+            )
+
+            request, channel = self.make_request(
+                "GET",
+                "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+                shorthand=False,
+            )
+            request.render(self.preview_url)
+            self.pump()
+
+            client = self.reactor.tcpClients[0][2].buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                (
+                    b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                    b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+                )
+                % (len(oembed_content),)
+                + oembed_content
+            )
+
+            self.pump()
+
+            client = self.reactor.tcpClients[1][2].buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                (
+                    b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                    b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+                )
+                % (len(end_content),)
+                + end_content
+            )
+
+            self.pump()
+
+            self.assertEqual(channel.code, 200)
+            self.assertEqual(
+                channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+            )
+
+    def test_oembed_rich(self):
+        """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+        # Route the HTTP version to an HTTP endpoint so that the tests work.
+        with patch.dict(
+            "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+            {
+                re.compile(
+                    r"http://twitter\.com/.+/status/.+"
+                ): "http://publish.twitter.com/oembed",
+            },
+            clear=True,
+        ):
+
+            self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+            result = {
+                "version": "1.0",
+                "type": "rich",
+                "html": "<div>Content Preview</div>",
+            }
+            end_content = json.dumps(result).encode("utf-8")
+
+            request, channel = self.make_request(
+                "GET",
+                "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+                shorthand=False,
+            )
+            request.render(self.preview_url)
+            self.pump()
+
+            client = self.reactor.tcpClients[0][2].buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                (
+                    b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                    b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+                )
+                % (len(end_content),)
+                + end_content
+            )
+
+            self.pump()
+            self.assertEqual(channel.code, 200)
+            self.assertEqual(
+                channel.json_body,
+                {"og:title": None, "og:description": "Content Preview"},
+            )
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 99908edba3..7f70353b0d 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -275,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
             self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
         )
 
-        token = self.get_success(self.event_source.get_current_token())
+        token = self.event_source.get_current_token()
         events, _ = self.get_success(
             self.store.get_recent_events_for_room(
                 room_id, limit=100, end_token=token.room_key
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..319e2c2325 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         self.table_name = "table_" + hs.get_secrets().token_hex(6)
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "create",
                 lambda x, *a: x.execute(*a),
                 "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "index",
                 lambda x, *a: x.execute(*a),
                 "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["hello"], ["there"]]
 
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "test",
-                self.storage.db.simple_upsert_many_txn,
+                self.storage.db_pool.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.db.simple_select_list(
+            self.storage.db_pool.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["bleb"]]
 
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "test",
-                self.storage.db.simple_upsert_many_txn,
+                self.storage.db_pool.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.db.simple_select_list(
+            self.storage.db_pool.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..1b516b7976 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,11 +24,11 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database, make_conn
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..2efbc97c2e 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -9,7 +9,9 @@ from tests import unittest
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates = self.hs.get_datastore().db.updates  # type: BackgroundUpdater
+        self.updates = (
+            self.hs.get_datastore().db_pool.updates
+        )  # type: BackgroundUpdater
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
@@ -29,7 +31,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         store = self.hs.get_datastore()
         self.get_success(
-            store.db.simple_insert(
+            store.db_pool.simple_insert(
                 "background_updates",
                 values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
             )
@@ -40,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
         def update(progress, count):
             yield self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield store.db.runInteraction(
+            yield store.db_pool.runInteraction(
                 "update_progress",
                 self.updates._background_update_progress_txn,
                 "test_update",
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index b589506c60..efcaeef1e7 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,7 +21,7 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.engines import create_engine
 
 from tests import unittest
@@ -57,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
 
-        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
         db._db_pool = self.db_pool
 
         self.datastore = SQLBaseStore(db, None, hs)
@@ -66,7 +66,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_insert(
+        yield self.datastore.db_pool.simple_insert(
             table="tablename", values={"columname": "Value"}
         )
 
@@ -78,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_insert(
+        yield self.datastore.db_pool.simple_insert(
             table="tablename",
             # Use OrderedDict() so we can assert on the SQL generated
             values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -93,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore.db.simple_select_one_onecol(
+        value = yield self.datastore.db_pool.simple_select_one_onecol(
             table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
         )
 
@@ -107,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore.db.simple_select_one(
+        ret = yield self.datastore.db_pool.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             retcols=["colA", "colB", "colC"],
@@ -123,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore.db.simple_select_one(
+        ret = yield self.datastore.db_pool.simple_select_one(
             table="tablename",
             keyvalues={"keycol": "Not here"},
             retcols=["colA"],
@@ -138,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore.db.simple_select_list(
+        ret = yield self.datastore.db_pool.simple_select_list(
             table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
         )
 
@@ -151,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_update_one(
+        yield self.datastore.db_pool.simple_update_one(
             table="tablename",
             keyvalues={"keycol": "TheKey"},
             updatevalues={"columnname": "New Value"},
@@ -166,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_update_one(
+        yield self.datastore.db_pool.simple_update_one(
             table="tablename",
             keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
             updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -181,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_delete_one(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_delete_one(
+        yield self.datastore.db_pool.simple_delete_one(
             table="tablename", keyvalues={"keycol": "Go away"}
         )
 
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..3fab5a5248 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         """
         # Make sure we don't clash with in progress updates.
         self.assertTrue(
-            self.store.db.updates._all_done, "Background updates are still ongoing"
+            self.store.db_pool.updates._all_done, "Background updates are still ongoing"
         )
 
         schema_path = os.path.join(
             prepare_database.dir_path,
-            "data_stores",
+            "databases",
             "main",
             "schema",
             "delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             prepare_database.executescript(txn, schema_path)
 
         self.get_success(
-            self.store.db.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test_delete_forward_extremities", run_delta_file
             )
         )
 
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
     def test_soft_failed_extremities_handled_correctly(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..224ea6fd79 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -86,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -204,10 +204,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
     def test_devices_last_seen_bg_update(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         user_id = "@user:id"
@@ -225,7 +225,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But clear the associated entry in devices table
         self.get_success(
-            self.store.db.simple_update(
+            self.store.db_pool.simple_update(
                 table="devices",
                 keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +252,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "devices_last_seen",
@@ -263,14 +263,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # We should now get the correct result again
@@ -293,10 +293,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
     def test_old_user_ips_pruned(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         user_id = "@user:id"
@@ -315,7 +315,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should see that in the DB
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +341,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should get no results.
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
 
         for i in range(0, 20):
-            self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+            self.get_success(
+                self.store.db_pool.runInteraction("insert", insert_event, i)
+            )
 
         # this should get the last ten
         r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         for i in range(0, 20):
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room1)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room1)
             )
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room2)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room2)
             )
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room3)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room3)
             )
 
         # Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             depth = depth_map[event_id]
 
-            self.store.db.simple_insert_txn(
+            self.store.db_pool.simple_insert_txn(
                 txn,
                 table="events",
                 values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 },
             )
 
-            self.store.db.simple_insert_many_txn(
+            self.store.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_auth",
                 values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         for event_id in auth_graph:
             next_stream_ordering += 1
             self.get_success(
-                self.store.db.runInteraction(
+                self.store.db_pool.runInteraction(
                     "insert", insert_event, event_id, next_stream_ordering
                 )
             )
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..857db071d4 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
-        yield self.store.get_unread_push_actions_for_user_in_range_for_http(
-            USER_ID, 0, 1000, 20
+        yield defer.ensureDeferred(
+            self.store.get_unread_push_actions_for_user_in_range_for_http(
+                USER_ID, 0, 1000, 20
+            )
         )
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_email(self):
-        yield self.store.get_unread_push_actions_for_user_in_range_for_email(
-            USER_ID, 0, 1000, 20
+        yield defer.ensureDeferred(
+            self.store.get_unread_push_actions_for_user_in_range_for_email(
+                USER_ID, 0, 1000, 20
+            )
         )
 
     @defer.inlineCallbacks
@@ -56,7 +60,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.db.runInteraction(
+            counts = yield self.store.db_pool.runInteraction(
                 "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
             )
             self.assertEquals(
@@ -72,10 +76,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            yield self.store.add_push_actions_to_staging(
-                event.event_id, {user_id: action}
+            yield defer.ensureDeferred(
+                self.store.add_push_actions_to_staging(
+                    event.event_id, {user_id: action}
+                )
             )
-            yield self.store.db.runInteraction(
+            yield self.store.db_pool.runInteraction(
                 "",
                 self.persist_events_store._set_push_actions_for_event_and_users_txn,
                 [(event, None)],
@@ -83,12 +89,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             )
 
         def _rotate(stream):
-            return self.store.db.runInteraction(
+            return self.store.db_pool.runInteraction(
                 "", self.store._rotate_notifs_before_txn, stream
             )
 
         def _mark_read(stream, depth):
-            return self.store.db.runInteraction(
+            return self.store.db_pool.runInteraction(
                 "",
                 self.store._remove_old_push_actions_before_txn,
                 room_id,
@@ -117,7 +123,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _inject_actions(6, PlAIN_NOTIF)
         yield _rotate(7)
 
-        yield self.store.db.simple_delete(
+        yield self.store.db_pool.simple_delete(
             table="event_push_actions", keyvalues={"1": 1}, desc=""
         )
 
@@ -136,7 +142,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store.db.simple_insert(
+            return self.store.db_pool.simple_insert(
                 "events",
                 {
                     "stream_ordering": so,
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..e845410dae 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.unittest import HomeserverTestCase
@@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db = self.store.db  # type: Database
+        self.db_pool = self.store.db_pool  # type: DatabasePool
 
-        self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
     def _setup_db(self, txn):
         txn.execute("CREATE SEQUENCE foobar_seq")
@@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         def _create(conn):
             return MultiWriterIdGenerator(
                 conn,
-                self.db,
+                self.db_pool,
                 instance_name=instance_name,
                 table="foobar",
                 instance_column="instance_name",
@@ -55,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 sequence_name="foobar_seq",
             )
 
-        return self.get_success(self.db.runWithConnection(_create))
+        return self.get_success(self.db_pool.runWithConnection(_create))
 
     def _insert_rows(self, instance_name: str, number: int):
         def _insert(txn):
@@ -65,7 +65,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                     (instance_name,),
                 )
 
-        self.get_success(self.db.runInteraction("test_single_instance", _insert))
+        self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
 
     def test_empty(self):
         """Test an ID generator against an empty database gives sensible
@@ -178,7 +178,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             self.assertEqual(id_gen.get_positions(), {"master": 7})
             self.assertEqual(id_gen.get_current_token("master"), 7)
 
-        self.get_success(self.db.runInteraction("test", _get_next_txn))
+        self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
         self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..259f2215f1 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -78,7 +78,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         # XXX why are we doing this here? this function is only run at startup
         # so it is odd to re-run it here.
         self.get_success(
-            self.store.db.runInteraction(
+            self.store.db_pool.runInteraction(
                 "initialise", self.store._initialise_reserved_users, threepids
             )
         )
@@ -204,7 +204,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
                     self.store.user_add_threepid(user, "email", email, now, now)
                 )
 
-        d = self.store.db.runInteraction(
+        d = self.store.db_pool.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.get_success(d)
@@ -280,7 +280,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         ]
 
         self.hs.config.mau_limits_reserved_threepids = threepids
-        d = self.store.db.runInteraction(
+        d = self.store.db_pool.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.get_success(d)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..a6012c973d 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from synapse.rest.client.v1 import room
 
 from tests.unittest import HomeserverTestCase
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
         event = self.successResultOf(event)
 
         # Purge everything before this topological token
-        purge = storage.purge_events.purge_history(self.room_id, event, True)
+        purge = defer.ensureDeferred(
+            storage.purge_events.purge_history(self.room_id, event, True)
+        )
         self.pump()
         self.assertEqual(self.successResultOf(purge), None)
 
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
         )
 
         # Purge everything before this topological token
-        purge = storage.purge_history(self.room_id, event, True)
+        purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
         self.pump()
         f = self.failureResultOf(purge)
         self.assertIn("greater than forward", f.value.args[0])
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..41511d479f 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
             @defer.inlineCallbacks
             def build(self, prev_event_ids):
-                built_event = yield self._base_builder.build(prev_event_ids)
+                built_event = yield defer.ensureDeferred(
+                    self._base_builder.build(prev_event_ids)
+                )
 
                 built_event._event_id = self._event_id
                 built_event._dict["event_id"] = self._event_id
@@ -341,7 +343,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         event_json = self.get_success(
-            self.store.db.simple_select_one_onecol(
+            self.store.db_pool.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -359,7 +361,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(60 * 60 * 2)
 
         event_json = self.get_success(
-            self.store.db.simple_select_one_onecol(
+            self.store.db_pool.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1d77b4a2d6..d07b985a8e 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
         self.alias = RoomAlias.from_string("#a-room-name:test")
         self.u_creator = UserID.from_string("@creator:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id=self.u_creator.to_string(),
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id=self.u_creator.to_string(),
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
@@ -88,17 +90,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
         self.room = RoomID.from_string("!abcde:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.storage.persistence.persist_event(
-            self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(
+                self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+            )
         )
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f282921538..17c9da4838 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -179,10 +179,10 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def test_can_rerun_update(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Now let's create a room, which will insert a membership
@@ -192,7 +192,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "current_state_events_membership",
@@ -203,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a0e133cd4a..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
@@ -68,7 +70,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
 
         return event
 
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups_ids(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.storage.state.get_state_for_event(e5.event_id)
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(e5.event_id)
+        )
 
         self.assertIsNotNone(e4)
 
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+            )
         )
 
         self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: {self.u_alice.to_string()}},
-                include_others=True,
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    include_others=True,
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: set()}, include_others=True
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.storage.state.get_state_groups_ids(
-            room_id, [e5.event_id]
+        group_ids = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 87a16d7d7a..c2f12c2741 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         prev_events that said event references.
         """
 
-        def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(destination, path, data, headers=None, timeout=0):
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}
diff --git a/tests/test_server.py b/tests/test_server.py
index 073b2362cc..d628070e48 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -157,6 +157,29 @@ class JsonResourceTests(unittest.TestCase):
         self.assertEqual(channel.json_body["error"], "Unrecognized request")
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
+    def test_head_request(self):
+        """
+        JsonResource.handler_for_request gives correctly decoded URL args to
+        the callback, while Twisted will give the raw bytes of URL query
+        arguments.
+        """
+
+        def _callback(request, **kwargs):
+            return 200, {"result": True}
+
+        res = JsonResource(self.homeserver)
+        res.register_paths(
+            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+        )
+
+        # The path was registered as GET, but this is a HEAD request.
+        request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertNotIn("body", channel.result)
+        self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
+
 
 class OptionsResourceTests(unittest.TestCase):
     def setUp(self):
@@ -255,7 +278,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         self.reactor = ThreadedMemoryReactorClock()
 
     def test_good_response(self):
-        def callback(request):
+        async def callback(request):
             request.write(b"response")
             request.finish()
 
@@ -275,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         with the right location.
         """
 
-        def callback(request, **kwargs):
+        async def callback(request, **kwargs):
             raise RedirectException(b"/look/an/eagle", 301)
 
         res = WrapHtmlRequestHandlerTests.TestResource()
@@ -295,7 +318,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         returned too
         """
 
-        def callback(request, **kwargs):
+        async def callback(request, **kwargs):
             e = RedirectException(b"/no/over/there", 304)
             e.cookies.append(b"session=yespls")
             raise e
@@ -312,3 +335,19 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         self.assertEqual(location_headers, [b"/no/over/there"])
         cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
         self.assertEqual(cookies_headers, [b"session=yespls"])
+
+    def test_head_request(self):
+        """A head request should work by being turned into a GET request."""
+
+        async def callback(request):
+            request.write(b"response")
+            request.finish()
+
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
+
+        request, channel = make_request(self.reactor, b"HEAD", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 4858e8fc59..b5c3667d2a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertEqual(2, len(prev_state_ids))
 
         self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
 
         self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_e = context_store["E"]
 
-        prev_state_ids = yield ctx_e.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
         self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
         self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
         self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
         ctx_b = context_store["B"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
 
         self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
             self.state.compute_event_context(event, old_state=old_state)
         )
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
             self.state.compute_event_context(event, old_state=old_state)
         )
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
         current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
 
         self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
 
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index b371efc0df..531a9b9118 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.store = self.hs.get_datastore()
         self.storage = self.hs.get_storage()
 
-        yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+        yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
 
     @defer.inlineCallbacks
     def test_filtering(self):
@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             evt = yield self.inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)
 
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
 
         # ... and the filtering happens.
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )
 
         for i in range(0, len(events_to_filter)):
@@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         event, context = yield defer.ensureDeferred(
             self.event_creation_handler.create_new_client_event(builder)
         )
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -265,8 +271,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         storage.main = test_store
         storage.state = test_store
 
-        filtered = yield filter_events_for_server(
-            test_store, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(test_store, "test_server", events_to_filter)
         )
         logger.info("Filtering took %f seconds", time.time() - start)
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 68d2586efd..2152c693f2 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -422,8 +422,8 @@ class HomeserverTestCase(TestCase):
 
         async def run_bg_updates():
             with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
-                while not await stor.db.updates.has_completed_background_updates():
-                    await stor.db.updates.do_next_background_update(1)
+                while not await stor.db_pool.updates.has_completed_background_updates():
+                    await stor.db_pool.updates.do_next_background_update(1)
 
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()
@@ -571,7 +571,7 @@ class HomeserverTestCase(TestCase):
         Add the given event as an extremity to the room.
         """
         self.get_success(
-            self.hs.get_datastore().db.simple_insert(
+            self.hs.get_datastore().db_pool.simple_insert(
                 table="event_forward_extremities",
                 values={"room_id": room_id, "event_id": event_id},
                 desc="test_add_extremity",
diff --git a/tests/utils.py b/tests/utils.py
index ac643679aa..a61cbdef44 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -154,6 +154,10 @@ def default_config(name, parse=False):
             "account": {"per_second": 10000, "burst_count": 10000},
             "failed_attempts": {"per_second": 10000, "burst_count": 10000},
         },
+        "rc_joins": {
+            "local": {"per_second": 10000, "burst_count": 10000},
+            "remote": {"per_second": 10000, "burst_count": 10000},
+        },
         "saml2_enabled": False,
         "public_baseurl": None,
         "default_identity_server": None,
@@ -638,14 +642,8 @@ class DeferredMockCallable(object):
             )
 
 
-@defer.inlineCallbacks
-def create_room(hs, room_id, creator_id):
+async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room
-
-    Args:
-        hs
-        room_id (str)
-        creator_id (str)
     """
 
     persistence_store = hs.get_storage().persistence
@@ -653,7 +651,7 @@ def create_room(hs, room_id, creator_id):
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
-    yield store.store_room(
+    await store.store_room(
         room_id=room_id,
         room_creator_user_id=creator_id,
         is_public=False,
@@ -671,8 +669,6 @@ def create_room(hs, room_id, creator_id):
         },
     )
 
-    event, context = yield defer.ensureDeferred(
-        event_creation_handler.create_new_client_event(builder)
-    )
+    event, context = await event_creation_handler.create_new_client_event(builder)
 
-    yield persistence_store.persist_event(event, context)
+    await persistence_store.persist_event(event, context)