summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py71
-rw-r--r--tests/api/test_filtering.py72
-rw-r--r--tests/api/test_ratelimiting.py73
-rw-r--r--tests/app/test_frontend_proxy.py26
-rw-r--r--tests/app/test_openid_listener.py8
-rw-r--r--tests/appservice/test_appservice.py89
-rw-r--r--tests/appservice/test_scheduler.py19
-rw-r--r--tests/config/test_base.py82
-rw-r--r--tests/crypto/test_keyring.py58
-rw-r--r--tests/events/test_snapshot.py36
-rw-r--r--tests/federation/test_complexity.py134
-rw-r--r--tests/federation/test_federation_sender.py35
-rw-r--r--tests/federation/test_federation_server.py33
-rw-r--r--tests/federation/transport/test_server.py2
-rw-r--r--tests/handlers/test_appservice.py69
-rw-r--r--tests/handlers/test_auth.py23
-rw-r--r--tests/handlers/test_device.py13
-rw-r--r--tests/handlers/test_directory.py5
-rw-r--r--tests/handlers/test_e2e_keys.py316
-rw-r--r--tests/handlers/test_e2e_room_keys.py374
-rw-r--r--tests/handlers/test_oidc.py59
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py82
-rw-r--r--tests/handlers/test_register.py287
-rw-r--r--tests/handlers/test_stats.py112
-rw-r--r--tests/handlers/test_typing.py60
-rw-r--r--tests/handlers/test_user_directory.py119
-rw-r--r--tests/http/__init__.py2
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py87
-rw-r--r--tests/http/federation/test_srv_resolver.py26
-rw-r--r--tests/http/test_additional_resource.py62
-rw-r--r--tests/http/test_fedclient.py100
-rw-r--r--tests/http/test_servlet.py80
-rw-r--r--tests/logging/test_structured.py4
-rw-r--r--tests/module_api/test_api.py2
-rw-r--r--tests/push/test_email.py4
-rw-r--r--tests/push/test_http.py352
-rw-r--r--tests/push/test_push_rule_evaluator.py56
-rw-r--r--tests/replication/_base.py170
-rw-r--r--tests/replication/slave/storage/test_events.py14
-rw-r--r--tests/replication/tcp/streams/test_events.py150
-rw-r--r--tests/replication/tcp/streams/test_typing.py88
-rw-r--r--tests/replication/test_client_reader_shard.py96
-rw-r--r--tests/replication/test_federation_ack.py1
-rw-r--r--tests/replication/test_federation_sender_shard.py234
-rw-r--r--tests/replication/test_pusher_shard.py193
-rw-r--r--tests/rest/admin/test_admin.py141
-rw-r--r--tests/rest/admin/test_room.py2507
-rw-r--r--tests/rest/admin/test_user.py56
-rw-r--r--tests/rest/client/test_retention.py100
-rw-r--r--tests/rest/client/test_shadow_banned.py312
-rw-r--r--tests/rest/client/third_party_rules.py2
-rw-r--r--tests/rest/client/v1/test_login.py159
-rw-r--r--tests/rest/client/v1/test_presence.py2
-rw-r--r--tests/rest/client/v1/test_profile.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py96
-rw-r--r--tests/rest/client/v1/test_typing.py6
-rw-r--r--tests/rest/client/v1/utils.py36
-rw-r--r--tests/rest/client/v2_alpha/test_account.py175
-rw-r--r--tests/rest/client/v2_alpha/test_filter.py8
-rw-r--r--tests/rest/client/v2_alpha/test_register.py6
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py11
-rw-r--r--tests/rest/client/v2_alpha/test_shared_rooms.py138
-rw-r--r--tests/rest/client/v2_alpha/test_sync.py157
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/rest/media/v1/test_media_storage.py13
-rw-r--r--tests/rest/media/v1/test_url_preview.py148
-rw-r--r--tests/rest/test_health.py34
-rw-r--r--tests/server.py36
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py23
-rw-r--r--tests/state/test_v2.py32
-rw-r--r--tests/storage/test__base.py34
-rw-r--r--tests/storage/test_appservice.py71
-rw-r--r--tests/storage/test_background_update.py20
-rw-r--r--tests/storage/test_base.py87
-rw-r--r--tests/storage/test_cleanup_extrems.py21
-rw-r--r--tests/storage/test_client_ips.py31
-rw-r--r--tests/storage/test_devices.py52
-rw-r--r--tests/storage/test_directory.py38
-rw-r--r--tests/storage/test_end_to_end_keys.py62
-rw-r--r--tests/storage/test_event_federation.py16
-rw-r--r--tests/storage/test_event_metrics.py2
-rw-r--r--tests/storage/test_event_push_actions.py138
-rw-r--r--tests/storage/test_id_generators.py217
-rw-r--r--tests/storage/test_main.py14
-rw-r--r--tests/storage/test_monthly_active_users.py31
-rw-r--r--tests/storage/test_profile.py27
-rw-r--r--tests/storage/test_purge.py51
-rw-r--r--tests/storage/test_redaction.py12
-rw-r--r--tests/storage/test_registration.py95
-rw-r--r--tests/storage/test_room.py62
-rw-r--r--tests/storage/test_roommember.py76
-rw-r--r--tests/storage/test_state.py80
-rw-r--r--tests/storage/test_user_directory.py20
-rw-r--r--tests/test_federation.py107
-rw-r--r--tests/test_mau.py2
-rw-r--r--tests/test_server.py127
-rw-r--r--tests/test_state.py134
-rw-r--r--tests/test_terms_auth.py9
-rw-r--r--tests/test_utils/__init__.py7
-rw-r--r--tests/test_utils/event_injection.py37
-rw-r--r--tests/test_visibility.py47
-rw-r--r--tests/unittest.py44
-rw-r--r--tests/util/caches/test_descriptors.py32
-rw-r--r--tests/util/test_file_consumer.py6
-rw-r--r--tests/util/test_linearizer.py2
-rw-r--r--tests/util/test_logcontext.py4
-rw-r--r--tests/util/test_retryutils.py46
-rw-r--r--tests/util/test_rwlock.py6
-rw-r--r--tests/util/test_stringutils.py3
-rw-r--r--tests/util/test_threepids.py49
-rw-r--r--tests/utils.py27
112 files changed, 7410 insertions, 2722 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..8ab56ec94c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -36,7 +36,7 @@ from tests import unittest
 from tests.utils import mock_getRawHeaders, setup_test_homeserver
 
 
-class TestHandlers(object):
+class TestHandlers:
     def __init__(self, hs):
         self.auth_handler = synapse.handlers.auth.AuthHandler(hs)
 
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
+        self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
         self.store.is_support_user = Mock(return_value=defer.succeed(False))
 
     @defer.inlineCallbacks
     def test_get_user_by_req_user_valid_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
-        self.store.get_user_by_access_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(
+            return_value=defer.succeed(user_info)
+        )
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_user_missing_token(self):
         user_info = {"name": self.test_user, "token_id": "ditto"}
-        self.store.get_user_by_access_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(
+            return_value=defer.succeed(user_info)
+        )
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, MissingClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "131.111.8.42"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, InvalidClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         f = self.failureResultOf(d, MissingClientTokenError).value
         self.assertEqual(f.code, 401)
         self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        # This just needs to return a truth-y value.
+        self.store.get_user_by_id = Mock(
+            return_value=defer.succeed({"is_guest": False})
+        )
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         request = Mock(args={})
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        d = self.auth.get_user_by_req(request)
+        d = defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.failureResultOf(d, AuthError)
 
     @defer.inlineCallbacks
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = Mock(
-            return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+            return_value=defer.succeed(
+                {"name": "@baldrick:matrix.org", "device_id": "device"}
+            )
         )
 
         user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_guest_user_from_macaroon(self):
-        self.store.get_user_by_id = Mock(return_value={"is_guest": True})
-        self.store.get_user_by_access_token = Mock(return_value=None)
+        self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+        self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
 
         def get_user(tok):
             if token != tok:
-                return None
-            return {
-                "name": USER_ID,
-                "is_guest": False,
-                "token_id": 1234,
-                "device_id": "DEVICE",
-            }
+                return defer.succeed(None)
+            return defer.succeed(
+                {
+                    "name": USER_ID,
+                    "is_guest": False,
+                    "token_id": 1234,
+                    "device_id": "DEVICE",
+                }
+            )
 
         self.store.get_user_by_access_token = get_user
-        self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+        self.store.get_user_by_id = Mock(
+            return_value=defer.succeed({"is_guest": False})
+        )
 
         # check the token works
         request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..d2d535d23c 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -369,14 +369,18 @@ class FilteringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_filter_presence_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
         event = MockEvent(sender="@foo:bar", type="m.profile")
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_presence(events=events)
@@ -386,8 +390,10 @@ class FilteringTestCase(unittest.TestCase):
     def test_filter_presence_no_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
 
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart + "2", user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart + "2", user_filter=user_filter_json
+            )
         )
         event = MockEvent(
             event_id="$asdasd:localhost",
@@ -396,8 +402,10 @@ class FilteringTestCase(unittest.TestCase):
         )
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart + "2", filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart + "2", filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_presence(events=events)
@@ -406,14 +414,18 @@ class FilteringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_filter_room_state_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
         event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_room_state(events=events)
@@ -422,16 +434,20 @@ class FilteringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_filter_room_state_no_match(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
         event = MockEvent(
             sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
         )
         events = [event]
 
-        user_filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        user_filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         results = user_filter.filter_room_state(events)
@@ -457,16 +473,20 @@ class FilteringTestCase(unittest.TestCase):
     def test_add_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
-        filter_id = yield self.filtering.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.filtering.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
 
         self.assertEquals(filter_id, 0)
         self.assertEquals(
             user_filter_json,
             (
-                yield self.datastore.get_user_filter(
-                    user_localpart=user_localpart, filter_id=0
+                yield defer.ensureDeferred(
+                    self.datastore.get_user_filter(
+                        user_localpart=user_localpart, filter_id=0
+                    )
                 )
             ),
         )
@@ -475,12 +495,16 @@ class FilteringTestCase(unittest.TestCase):
     def test_get_filter(self):
         user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
 
-        filter_id = yield self.datastore.add_user_filter(
-            user_localpart=user_localpart, user_filter=user_filter_json
+        filter_id = yield defer.ensureDeferred(
+            self.datastore.add_user_filter(
+                user_localpart=user_localpart, user_filter=user_filter_json
+            )
         )
 
-        filter = yield self.filtering.get_user_filter(
-            user_localpart=user_localpart, filter_id=filter_id
+        filter = yield defer.ensureDeferred(
+            self.filtering.get_user_filter(
+                user_localpart=user_localpart, filter_id=filter_id
+            )
         )
 
         self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index d580e729c5..1e1f30d790 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -1,4 +1,6 @@
 from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
+from synapse.appservice import ApplicationService
+from synapse.types import create_requester
 
 from tests import unittest
 
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
 
+    def test_allowed_user_via_can_requester_do_action(self):
+        user_requester = create_requester("@user:example.com")
+        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        allowed, time_allowed = limiter.can_requester_do_action(
+            user_requester, _time_now_s=0
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            user_requester, _time_now_s=5
+        )
+        self.assertFalse(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            user_requester, _time_now_s=10
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(20.0, time_allowed)
+
+    def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
+        appservice = ApplicationService(
+            None, "example.com", id="foo", rate_limited=True,
+        )
+        as_requester = create_requester("@user:example.com", app_service=appservice)
+
+        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=0
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=5
+        )
+        self.assertFalse(allowed)
+        self.assertEquals(10.0, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=10
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(20.0, time_allowed)
+
+    def test_allowed_appservice_via_can_requester_do_action(self):
+        appservice = ApplicationService(
+            None, "example.com", id="foo", rate_limited=False,
+        )
+        as_requester = create_requester("@user:example.com", app_service=appservice)
+
+        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=0
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(-1, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=5
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(-1, time_allowed)
+
+        allowed, time_allowed = limiter.can_requester_do_action(
+            as_requester, _time_now_s=10
+        )
+        self.assertTrue(allowed)
+        self.assertEquals(-1, time_allowed)
+
     def test_allowed_via_ratelimit(self):
         limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
 
diff --git a/tests/app/test_frontend_proxy.py b/tests/app/test_frontend_proxy.py
index be20a89682..641093d349 100644
--- a/tests/app/test_frontend_proxy.py
+++ b/tests/app/test_frontend_proxy.py
@@ -30,6 +30,16 @@ class FrontendProxyTests(HomeserverTestCase):
     def default_config(self):
         c = super().default_config()
         c["worker_app"] = "synapse.app.frontend_proxy"
+
+        c["worker_listeners"] = [
+            {
+                "type": "http",
+                "port": 8080,
+                "bind_addresses": ["0.0.0.0"],
+                "resources": [{"names": ["client"]}],
+            }
+        ]
+
         return c
 
     def test_listen_http_with_presence_enabled(self):
@@ -39,14 +49,8 @@ class FrontendProxyTests(HomeserverTestCase):
         # Presence is on
         self.hs.config.use_presence = True
 
-        config = {
-            "port": 8080,
-            "bind_addresses": ["0.0.0.0"],
-            "resources": [{"names": ["client"]}],
-        }
-
         # Listen with the config
-        self.hs._listen_http(config)
+        self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
 
         # Grab the resource from the site that was told to listen
         self.assertEqual(len(self.reactor.tcpServers), 1)
@@ -67,14 +71,8 @@ class FrontendProxyTests(HomeserverTestCase):
         # Presence is off
         self.hs.config.use_presence = False
 
-        config = {
-            "port": 8080,
-            "bind_addresses": ["0.0.0.0"],
-            "resources": [{"names": ["client"]}],
-        }
-
         # Listen with the config
-        self.hs._listen_http(config)
+        self.hs._listen_http(self.hs.config.worker.worker_listeners[0])
 
         # Grab the resource from the site that was told to listen
         self.assertEqual(len(self.reactor.tcpServers), 1)
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 7364f9f1ec..0f016c32eb 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -18,6 +18,7 @@ from parameterized import parameterized
 
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.app.homeserver import SynapseHomeServer
+from synapse.config.server import parse_listener_def
 
 from tests.unittest import HomeserverTestCase
 
@@ -35,6 +36,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         # have to tell the FederationHandler not to try to access stuff that is only
         # in the primary store.
         conf["worker_app"] = "yes"
+
         return conf
 
     @parameterized.expand(
@@ -53,12 +55,13 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         """
         config = {
             "port": 8080,
+            "type": "http",
             "bind_addresses": ["0.0.0.0"],
             "resources": [{"names": names}],
         }
 
         # Listen with the config
-        self.hs._listen_http(config)
+        self.hs._listen_http(parse_listener_def(config))
 
         # Grab the resource from the site that was told to listen
         site = self.reactor.tcpServers[0][1]
@@ -101,12 +104,13 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
         """
         config = {
             "port": 8080,
+            "type": "http",
             "bind_addresses": ["0.0.0.0"],
             "resources": [{"names": names}],
         }
 
         # Listen with the config
-        self.hs._listener_http(config, config)
+        self.hs._listener_http(self.hs.get_config(), parse_listener_def(config))
 
         # Grab the resource from the site that was told to listen
         site = self.reactor.tcpServers[0][1]
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 4003869ed6..236b608d58 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
     def test_regex_user_id_prefix_match(self):
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_no_match(self):
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@someone_else:matrix.org"
-        self.assertFalse((yield self.service.is_interested(self.event)))
+        self.assertFalse(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_room_member_is_checked(self):
@@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.sender = "@someone_else:matrix.org"
         self.event.type = "m.room.member"
         self.event.state_key = "@irc_foobar:matrix.org"
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_room_id_match(self):
@@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_room_id_no_match(self):
@@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase):
             _regex("!some_prefix.*some_suffix:matrix.org")
         )
         self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
-        self.assertFalse((yield self.service.is_interested(self.event)))
+        self.assertFalse(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_regex_alias_match(self):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room.return_value = [
-            "#irc_foobar:matrix.org",
-            "#athing:matrix.org",
-        ]
-        self.store.get_users_in_room.return_value = []
-        self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+        self.store.get_aliases_for_room.return_value = defer.succeed(
+            ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+        )
+        self.store.get_users_in_room.return_value = defer.succeed([])
+        self.assertTrue(
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(self.event, self.store)
+                )
+            )
+        )
 
     def test_non_exclusive_alias(self):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
@@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room.return_value = [
-            "#xmpp_foobar:matrix.org",
-            "#athing:matrix.org",
-        ]
-        self.store.get_users_in_room.return_value = []
-        self.assertFalse((yield self.service.is_interested(self.event, self.store)))
+        self.store.get_aliases_for_room.return_value = defer.succeed(
+            ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+        )
+        self.store.get_users_in_room.return_value = defer.succeed([])
+        self.assertFalse(
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(self.event, self.store)
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     def test_regex_multiple_matches(self):
@@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
-        self.store.get_users_in_room.return_value = []
-        self.assertTrue((yield self.service.is_interested(self.event, self.store)))
+        self.store.get_aliases_for_room.return_value = defer.succeed(
+            ["#irc_barfoo:matrix.org"]
+        )
+        self.store.get_users_in_room.return_value = defer.succeed([])
+        self.assertTrue(
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(self.event, self.store)
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     def test_interested_in_self(self):
@@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.event.type = "m.room.member"
         self.event.content = {"membership": "invite"}
         self.event.state_key = self.service.sender
-        self.assertTrue((yield self.service.is_interested(self.event)))
+        self.assertTrue(
+            (yield defer.ensureDeferred(self.service.is_interested(self.event)))
+        )
 
     @defer.inlineCallbacks
     def test_member_list_match(self):
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
-        self.store.get_users_in_room.return_value = [
-            "@alice:here",
-            "@irc_fo:here",  # AS user
-            "@bob:here",
-        ]
-        self.store.get_aliases_for_room.return_value = []
+        # Note that @irc_fo:here is the AS user.
+        self.store.get_users_in_room.return_value = defer.succeed(
+            ["@alice:here", "@irc_fo:here", "@bob:here"]
+        )
+        self.store.get_aliases_for_room.return_value = defer.succeed([])
 
         self.event.sender = "@xmpp_foobar:matrix.org"
         self.assertTrue(
-            (yield self.service.is_interested(event=self.event, store=self.store))
+            (
+                yield defer.ensureDeferred(
+                    self.service.is_interested(event=self.event, store=self.store)
+                )
+            )
         )
diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 52f89d3f83..68a4caabbf 100644
--- a/tests/appservice/test_scheduler.py
+++ b/tests/appservice/test_scheduler.py
@@ -25,6 +25,7 @@ from synapse.appservice.scheduler import (
 from synapse.logging.context import make_deferred_yieldable
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 from ..utils import MockClock
 
@@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         self.store.get_appservice_state = Mock(
             return_value=defer.succeed(ApplicationServiceState.UP)
         )
-        txn.send = Mock(return_value=defer.succeed(True))
+        txn.send = Mock(return_value=make_awaitable(True))
         self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
 
         # actual call
-        self.txnctrl.send(service, events)
+        self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
             service=service, events=events  # txn made and saved
@@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
 
         # actual call
-        self.txnctrl.send(service, events)
+        self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
             service=service, events=events  # txn made and saved
@@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
             return_value=defer.succeed(ApplicationServiceState.UP)
         )
         self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
-        txn.send = Mock(return_value=defer.succeed(False))  # fails to send
+        txn.send = Mock(return_value=make_awaitable(False))  # fails to send
         self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
 
         # actual call
-        self.txnctrl.send(service, events)
+        self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
 
         self.store.create_appservice_txn.assert_called_once_with(
             service=service, events=events
@@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.recoverer.recover()
         # shouldn't have called anything prior to waiting for exp backoff
         self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = Mock(return_value=True)
+        txn.send = Mock(return_value=make_awaitable(True))
+        txn.complete.return_value = make_awaitable(None)
         # wait for exp backoff
         self.clock.advance_time(2)
         self.assertEquals(1, txn.send.call_count)
@@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
 
         self.recoverer.recover()
         self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = Mock(return_value=False)
+        txn.send = Mock(return_value=make_awaitable(False))
+        txn.complete.return_value = make_awaitable(None)
         self.clock.advance_time(2)
         self.assertEquals(1, txn.send.call_count)
         self.assertEquals(0, txn.complete.call_count)
@@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.assertEquals(3, txn.send.call_count)
         self.assertEquals(0, txn.complete.call_count)
         self.assertEquals(0, self.callback.call_count)
-        txn.send = Mock(return_value=True)  # successfully send the txn
+        txn.send = Mock(return_value=make_awaitable(True))  # successfully send the txn
         pop_txn = True  # returns the txn the first time, then no more.
         self.clock.advance_time(16)
         self.assertEquals(1, txn.send.call_count)  # new mock reset call count
diff --git a/tests/config/test_base.py b/tests/config/test_base.py
new file mode 100644
index 0000000000..42ee5f56d9
--- /dev/null
+++ b/tests/config/test_base.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os.path
+import tempfile
+
+from synapse.config import ConfigError
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+
+class BaseConfigTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.hs = hs
+
+    def test_loading_missing_templates(self):
+        # Use a temporary directory that exists on the system, but that isn't likely to
+        # contain template files
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Attempt to load an HTML template from our custom template directory
+            template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0]
+
+        # If no errors, we should've gotten the default template instead
+
+        # Render the template
+        a_random_string = random_string(5)
+        html_content = template.render({"error_description": a_random_string})
+
+        # Check that our string exists in the template
+        self.assertIn(
+            a_random_string,
+            html_content,
+            "Template file did not contain our test string",
+        )
+
+    def test_loading_custom_templates(self):
+        # Use a temporary directory that exists on the system
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Create a temporary bogus template file
+            with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template:
+                # Get temporary file's filename
+                template_filename = os.path.basename(tmp_template.name)
+
+                # Write a custom HTML template
+                contents = b"{{ test_variable }}"
+                tmp_template.write(contents)
+                tmp_template.flush()
+
+                # Attempt to load the template from our custom template directory
+                template = (
+                    self.hs.config.read_templates([template_filename], tmp_dir)
+                )[0]
+
+        # Render the template
+        a_random_string = random_string(5)
+        html_content = template.render({"test_variable": a_random_string})
+
+        # Check that our string exists in the template
+        self.assertIn(
+            a_random_string,
+            html_content,
+            "Template file did not contain our test string",
+        )
+
+    def test_loading_template_from_nonexistent_custom_directory(self):
+        with self.assertRaises(ConfigError):
+            self.hs.config.read_templates(
+                ["some_filename.html"], "a_nonexistent_directory"
+            )
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 70c8e72303..2e6e7abf1f 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -40,9 +40,10 @@ from synapse.logging.context import (
 from synapse.storage.keys import FetchKeyResult
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
-class MockPerspectiveServer(object):
+class MockPerspectiveServer:
     def __init__(self):
         self.server_name = "mock_server"
         self.key = signedjson.key.generate_signing_key(0)
@@ -102,11 +103,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         }
         persp_deferred = defer.Deferred()
 
-        @defer.inlineCallbacks
-        def get_perspectives(**kwargs):
+        async def get_perspectives(**kwargs):
             self.assertEquals(current_context().request, "11")
             with PreserveLoggingContext():
-                yield persp_deferred
+                await persp_deferred
             return persp_resp
 
         self.http_client.post_json.side_effect = get_perspectives
@@ -190,9 +190,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         # should fail immediately on an unsigned object
         d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
-        self.failureResultOf(d, SynapseError)
+        self.get_failure(d, SynapseError)
 
-        # should suceed on a signed object
+        # should succeed on a signed object
         d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
         # self.assertFalse(d.called)
         self.get_success(d)
@@ -202,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         with a null `ts_valid_until_ms`
         """
         mock_fetcher = keyring.KeyFetcher()
-        mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+        mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
         kr = keyring.Keyring(
             self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -221,7 +221,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         # should fail immediately on an unsigned object
         d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
-        self.failureResultOf(d, SynapseError)
+        self.get_failure(d, SynapseError)
 
         # should fail on a signed object with a non-zero minimum_valid_until_ms,
         # as it tries to refetch the keys and fails.
@@ -245,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key(1)
 
-        def get_keys(keys_to_fetch):
+        async def get_keys(keys_to_fetch):
             # there should only be one request object (with the max validity)
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
 
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                    }
+            return {
+                "server1": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
                 }
-            )
+            }
 
         mock_fetcher = keyring.KeyFetcher()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -282,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """If the first fetcher cannot provide a recent enough key, we fall back"""
         key1 = signedjson.key.generate_signing_key(1)
 
-        def get_keys1(keys_to_fetch):
+        async def get_keys1(keys_to_fetch):
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
-                    }
-                }
-            )
+            return {
+                "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+            }
 
-        def get_keys2(keys_to_fetch):
+        async def get_keys2(keys_to_fetch):
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                    }
+            return {
+                "server1": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
                 }
-            )
+            }
 
         mock_fetcher1 = keyring.KeyFetcher()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -355,7 +347,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         }
         signedjson.sign.sign_json(response, SERVER_NAME, testkey)
 
-        def get_json(destination, path, **kwargs):
+        async def get_json(destination, path, **kwargs):
             self.assertEqual(destination, SERVER_NAME)
             self.assertEqual(path, "/_matrix/key/v2/server/key1")
             return response
@@ -444,7 +436,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect a perspectives-server key query
         """
 
-        def post_json(destination, path, data, **kwargs):
+        async def post_json(destination, path, data, **kwargs):
             self.assertEqual(destination, self.mock_perspective_server.server_name)
             self.assertEqual(path, "/_matrix/key/v2/query")
 
@@ -580,14 +572,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         # remove the perspectives server's signature
         response = build_response()
         del response["signatures"][self.mock_perspective_server.server_name]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
 
         # remove the origin server's signature
         response = build_response()
         del response["signatures"][SERVER_NAME]
-        self.http_client.post_json.return_value = {"server_keys": [response]}
         keys = get_key_from_perspectives(response)
         self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
 
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 640f5f3bce..3a80626224 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase):
         serialize/deserialize.
         """
 
-        event, context = create_event(
-            self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+        event, context = self.get_success(
+            create_event(
+                self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+            )
         )
 
         self._check_serialize_deserialize(event, context)
@@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase):
         """Test that an EventContext for a state event (with not previous entry)
         is the same after serialize/deserialize.
         """
-        event, context = create_event(
-            self.hs,
-            room_id=self.room_id,
-            type="m.test",
-            sender=self.user_id,
-            state_key="",
+        event, context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.test",
+                sender=self.user_id,
+                state_key="",
+            )
         )
 
         self._check_serialize_deserialize(event, context)
@@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase):
         """Test that an EventContext for a state event (which replaces a
         previous entry) is the same after serialize/deserialize.
         """
-        event, context = create_event(
-            self.hs,
-            room_id=self.room_id,
-            type="m.room.member",
-            sender=self.user_id,
-            state_key=self.user_id,
-            content={"membership": "leave"},
+        event, context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.room.member",
+                sender=self.user_id,
+                state_key=self.user_id,
+                content={"membership": "leave"},
+            )
         )
 
         self._check_serialize_deserialize(event, context)
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 0c9987be54..3d880c499d 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -15,14 +15,13 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -59,7 +58,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
 
         # Artificially raise the complexity
         store = self.hs.get_datastore()
-        store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
+        store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
 
         # Get the room complexity again -- make sure it's our artificial value
         request, channel = self.make_request(
@@ -78,9 +77,44 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
+        fed_transport.client.get_json = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+        )
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+        )
+
+        d = handler._remote_join(
+            None,
+            ["other.example.com"],
+            "roomid",
+            UserID.from_string(u1),
+            {"membership": "join"},
+        )
+
+        self.pump()
+
+        # The request failed with a SynapseError saying the resource limit was
+        # exceeded.
+        f = self.get_failure(d, SynapseError)
+        self.assertEqual(f.value.code, 400, f.value)
+        self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+    def test_join_too_large_admin(self):
+        # Check whether an admin can join if option "admins_can_join" is undefined,
+        # this option defaults to false, so the join should fail.
+
+        u1 = self.register_user("u1", "pass", admin=True)
+
+        handler = self.hs.get_room_member_handler()
+        fed_transport = self.hs.get_federation_transport_client()
+
+        # Mock out some things, because we don't want to test the whole join
+        fed_transport.client.get_json = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+        )
+        handler.federation_handler.do_invite_join = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
         )
 
         d = handler._remote_join(
@@ -116,13 +150,15 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
+        fed_transport.client.get_json = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable(None)
+        )
         handler.federation_handler.do_invite_join = Mock(
-            return_value=defer.succeed(("", 1))
+            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
         )
 
         # Artificially raise the complexity
-        self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(
+        self.hs.get_datastore().get_current_state_event_counts = lambda x: make_awaitable(
             600
         )
 
@@ -141,3 +177,85 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         f = self.get_failure(d, SynapseError)
         self.assertEqual(f.value.code, 400)
         self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+
+class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
+    # Test the behavior of joining rooms which exceed the complexity if option
+    # limit_remote_rooms.admins_can_join is True.
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self):
+        config = super().default_config()
+        config["limit_remote_rooms"] = {
+            "enabled": True,
+            "complexity": 0.05,
+            "admins_can_join": True,
+        }
+        return config
+
+    def test_join_too_large_no_admin(self):
+        # A user which is not an admin should not be able to join a remote room
+        # which is too complex.
+
+        u1 = self.register_user("u1", "pass")
+
+        handler = self.hs.get_room_member_handler()
+        fed_transport = self.hs.get_federation_transport_client()
+
+        # Mock out some things, because we don't want to test the whole join
+        fed_transport.client.get_json = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+        )
+        handler.federation_handler.do_invite_join = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+        )
+
+        d = handler._remote_join(
+            None,
+            ["other.example.com"],
+            "roomid",
+            UserID.from_string(u1),
+            {"membership": "join"},
+        )
+
+        self.pump()
+
+        # The request failed with a SynapseError saying the resource limit was
+        # exceeded.
+        f = self.get_failure(d, SynapseError)
+        self.assertEqual(f.value.code, 400, f.value)
+        self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+
+    def test_join_too_large_admin(self):
+        # An admin should be able to join rooms where a complexity check fails.
+
+        u1 = self.register_user("u1", "pass", admin=True)
+
+        handler = self.hs.get_room_member_handler()
+        fed_transport = self.hs.get_federation_transport_client()
+
+        # Mock out some things, because we don't want to test the whole join
+        fed_transport.client.get_json = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable({"v1": 9999})
+        )
+        handler.federation_handler.do_invite_join = Mock(
+            side_effect=lambda *args, **kwargs: make_awaitable(("", 1))
+        )
+
+        d = handler._remote_join(
+            None,
+            ["other.example.com"],
+            "roomid",
+            UserID.from_string(u1),
+            {"membership": "join"},
+        )
+
+        self.pump()
+
+        # The request success since the user is an admin
+        self.get_success(d)
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ff12539041..5f512ff8bf 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -21,35 +21,39 @@ from signedjson.types import BaseKey, SigningKey
 
 from twisted.internet import defer
 
+from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.rest import admin
 from synapse.rest.client.v1 import login
 from synapse.types import JsonDict, ReadReceipt
 
+from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase, override_config
 
 
 class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
+        mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+        # Ensure a new Awaitable is created for each call.
+        mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+            ["test", "host2"]
+        )
         return self.setup_test_homeserver(
-            state_handler=Mock(spec=["get_current_hosts_in_room"]),
+            state_handler=mock_state_handler,
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
     @override_config({"send_federation": True})
     def test_send_receipts(self):
-        mock_state_handler = self.hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -80,19 +84,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def test_send_receipts_with_backoff(self):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
-        mock_state_handler = self.hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
-        mock_send_transaction.return_value = defer.succeed({})
+        mock_send_transaction.return_value = make_awaitable({})
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
 
         self.pump()
 
@@ -124,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         receipt = ReadReceipt(
             "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
         )
-        self.successResultOf(sender.send_read_receipt(receipt))
+        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
         self.pump()
         mock_send_transaction.assert_not_called()
 
@@ -163,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         return self.setup_test_homeserver(
-            state_handler=Mock(spec=["get_current_hosts_in_room"]),
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
@@ -173,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_current_hosts_in_room
-        mock_state_handler = hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         # stub out get_users_who_share_room_with_user so that it claims that
         # `@user2:host2` is in the room
         def get_users_who_share_room_with_user(user_id):
@@ -536,7 +532,10 @@ def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
     return {
         "user_id": user_id,
         "device_id": device_id,
-        "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+        "algorithms": [
+            "m.olm.curve25519-aes-sha2",
+            RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+        ],
         "keys": {
             "curve25519:" + device_id: "curve25519+key",
             key_id(sk): encode_pubkey(sk),
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 296dc887be..da933ecd75 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 import logging
 
+from parameterized import parameterized
+
 from synapse.events import make_event_from_dict
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.rest import admin
@@ -23,6 +25,37 @@ from synapse.rest.client.v1 import login, room
 from tests import unittest
 
 
+class FederationServerTests(unittest.FederatingHomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    @parameterized.expand([(b"",), (b"foo",), (b'{"limit": Infinity}',)])
+    def test_bad_request(self, query_content):
+        """
+        Querying with bad data returns a reasonable error code.
+        """
+        u1 = self.register_user("u1", "pass")
+        u1_token = self.login("u1", "pass")
+
+        room_1 = self.helper.create_room_as(u1, tok=u1_token)
+        self.inject_room_member(room_1, "@user:other.example.com", "join")
+
+        "/get_missing_events/(?P<room_id>[^/]*)/?"
+
+        request, channel = self.make_request(
+            "POST",
+            "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
+            query_content,
+        )
+        self.render(request)
+        self.assertEquals(400, channel.code, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
+
+
 class ServerACLsTestCase(unittest.TestCase):
     def test_blacklisted_server(self):
         e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 27d83bb7d9..72e22d655f 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -26,7 +26,7 @@ from tests.unittest import override_config
 
 class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        class Authenticator(object):
+        class Authenticator:
             def authenticate_request(self, request, content):
                 return defer.succeed("otherserver.nottld")
 
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ba7148ec01..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 
 from synapse.handlers.appservice import ApplicationServicesHandler
 
+from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
 from .. import unittest
@@ -32,10 +33,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api = Mock()
         self.mock_scheduler = Mock()
         hs = Mock()
-        hs.get_datastore = Mock(return_value=self.mock_store)
-        self.mock_store.get_received_ts.return_value = 0
-        hs.get_application_service_api = Mock(return_value=self.mock_as_api)
-        hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
+        hs.get_datastore.return_value = self.mock_store
+        self.mock_store.get_received_ts.return_value = defer.succeed(0)
+        self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+        hs.get_application_service_api.return_value = self.mock_as_api
+        hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
         self.handler = ApplicationServicesHandler(hs)
 
@@ -48,18 +50,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested=False),
         ]
 
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_user_by_id = Mock(return_value=[])
+        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_user_by_id.return_value = defer.succeed([])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            (0, [event]),
-            (0, []),
+            defer.succeed((0, [event])),
+            defer.succeed((0, [])),
         ]
-        self.mock_as_api.push = Mock()
-        yield self.handler.notify_interested_services(0)
+        yield defer.ensureDeferred(self.handler.notify_interested_services(0))
         self.mock_scheduler.submit_event_for_as.assert_called_once_with(
             interested_service, event
         )
@@ -68,36 +70,34 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_query_user_exists_unknown_user(self):
         user_id = "@someone:anywhere"
         services = [self._mkservice(is_interested=True)]
-        services[0].is_interested_in_user = Mock(return_value=True)
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_user_by_id = Mock(return_value=None)
+        services[0].is_interested_in_user.return_value = True
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_user_by_id.return_value = defer.succeed(None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.push = Mock()
-        self.mock_as_api.query_user = Mock()
+        self.mock_as_api.query_user.return_value = defer.succeed(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            (0, [event]),
-            (0, []),
+            defer.succeed((0, [event])),
+            defer.succeed((0, [])),
         ]
-        yield self.handler.notify_interested_services(0)
+        yield defer.ensureDeferred(self.handler.notify_interested_services(0))
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
 
     @defer.inlineCallbacks
     def test_query_user_exists_known_user(self):
         user_id = "@someone:anywhere"
         services = [self._mkservice(is_interested=True)]
-        services[0].is_interested_in_user = Mock(return_value=True)
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
+        services[0].is_interested_in_user.return_value = True
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.push = Mock()
-        self.mock_as_api.query_user = Mock()
+        self.mock_as_api.query_user.return_value = defer.succeed(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            (0, [event]),
-            (0, []),
+            defer.succeed((0, [event])),
+            defer.succeed((0, [])),
         ]
-        yield self.handler.notify_interested_services(0)
+        yield defer.ensureDeferred(self.handler.notify_interested_services(0))
         self.assertFalse(
             self.mock_as_api.query_user.called,
             "query_user called when it shouldn't have been.",
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_query_room_alias_exists(self):
         room_alias_str = "#foo:bar"
         room_alias = Mock()
-        room_alias.to_string = Mock(return_value=room_alias_str)
+        room_alias.to_string.return_value = room_alias_str
 
         room_id = "!alpha:bet"
         servers = ["aperture"]
@@ -118,12 +118,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice_alias(is_interested_in_alias=False),
         ]
 
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_association_from_room_alias = Mock(
-            return_value=Mock(room_id=room_id, servers=servers)
+        self.mock_as_api.query_alias.return_value = make_awaitable(True)
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
+            Mock(room_id=room_id, servers=servers)
         )
 
-        result = yield self.handler.query_room_alias_exists(room_alias)
+        result = yield defer.ensureDeferred(
+            self.handler.query_room_alias_exists(room_alias)
+        )
 
         self.mock_as_api.query_alias.assert_called_once_with(
             interested_service, room_alias_str
@@ -133,14 +136,14 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def _mkservice(self, is_interested):
         service = Mock()
-        service.is_interested = Mock(return_value=is_interested)
+        service.is_interested.return_value = make_awaitable(is_interested)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service
 
     def _mkservice_alias(self, is_interested_in_alias):
         service = Mock()
-        service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
+        service.is_interested_in_alias.return_value = is_interested_in_alias
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c01b04e1dc..c7efd3822d 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -24,10 +24,11 @@ from synapse.api.errors import ResourceLimitError
 from synapse.handlers.auth import AuthHandler
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
-class AuthHandlers(object):
+class AuthHandlers:
     def __init__(self, hs):
         self.auth_handler = AuthHandler(hs)
 
@@ -142,7 +143,7 @@ class AuthTestCase(unittest.TestCase):
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.large_number_of_users)
+            side_effect=lambda: make_awaitable(self.large_number_of_users)
         )
 
         with self.assertRaises(ResourceLimitError):
@@ -153,7 +154,7 @@ class AuthTestCase(unittest.TestCase):
             )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.large_number_of_users)
+            side_effect=lambda: make_awaitable(self.large_number_of_users)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -168,7 +169,7 @@ class AuthTestCase(unittest.TestCase):
 
         # If not in monthly active cohort
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.auth_blocking._max_mau_value)
+            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -178,7 +179,7 @@ class AuthTestCase(unittest.TestCase):
             )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.auth_blocking._max_mau_value)
+            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -188,10 +189,10 @@ class AuthTestCase(unittest.TestCase):
             )
         # If in monthly active cohort
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(self.hs.get_clock().time_msec())
+            side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
         )
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.auth_blocking._max_mau_value)
+            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
         )
         yield defer.ensureDeferred(
             self.auth_handler.get_access_token_for_user_id(
@@ -199,10 +200,10 @@ class AuthTestCase(unittest.TestCase):
             )
         )
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(self.hs.get_clock().time_msec())
+            side_effect=lambda user_id: make_awaitable(self.hs.get_clock().time_msec())
         )
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.auth_blocking._max_mau_value)
+            side_effect=lambda: make_awaitable(self.auth_blocking._max_mau_value)
         )
         yield defer.ensureDeferred(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -215,7 +216,7 @@ class AuthTestCase(unittest.TestCase):
         self.auth_blocking._limit_usage_by_mau = True
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.small_number_of_users)
+            side_effect=lambda: make_awaitable(self.small_number_of_users)
         )
         # Ensure does not raise exception
         yield defer.ensureDeferred(
@@ -225,7 +226,7 @@ class AuthTestCase(unittest.TestCase):
         )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.small_number_of_users)
+            side_effect=lambda: make_awaitable(self.small_number_of_users)
         )
         yield defer.ensureDeferred(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 62b47f6574..6aa322bf3a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.get_success(self.handler.delete_device(user1, "abc"))
 
         # check the device was deleted
-        res = self.handler.get_device(user1, "abc")
-        self.pump()
-        self.assertIsInstance(
-            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        self.get_failure(
+            self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
         )
 
         # we'd like to check the access token was invalidated, but that's a
@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
     def test_update_unknown_device(self):
         update = {"display_name": "new_display"}
-        res = self.handler.update_device("user_id", "unknown_device_id", update)
-        self.pump()
-        self.assertIsInstance(
-            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        self.get_failure(
+            self.handler.update_device("user_id", "unknown_device_id", update),
+            synapse.api.errors.NotFoundError,
         )
 
     def _record_users(self):
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 00bb776271..bc0c5aefdc 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -16,8 +16,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse
 import synapse.api.errors
 from synapse.api.constants import EventTypes
@@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room
 from synapse.types import RoomAlias, create_requester
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class DirectoryTestCase(unittest.HomeserverTestCase):
@@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
     def test_get_remote_association(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index e1e144b2e7..210ddcbb88 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,17 +14,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import mock
 
-import signedjson.key as key
-import signedjson.sign as sign
+from signedjson import key as key, sign as sign
 
 from twisted.internet import defer
 
 import synapse.handlers.e2e_keys
 import synapse.storage
 from synapse.api import errors
+from synapse.api.constants import RoomEncryptionAlgorithms
 
 from tests import unittest, utils
 
@@ -47,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         """If the user has no devices, we expect an empty list.
         """
         local_user = "@boris:" + self.hs.hostname
-        res = yield self.handler.query_local_devices({local_user: None})
+        res = yield defer.ensureDeferred(
+            self.handler.query_local_devices({local_user: None})
+        )
         self.assertDictEqual(res, {local_user: {}})
 
     @defer.inlineCallbacks
@@ -61,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
@@ -85,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+                )
             )
             self.fail("No error when changing string key")
         except errors.SynapseError:
             pass
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+                )
             )
             self.fail("No error when replacing dict key with string")
         except errors.SynapseError:
             pass
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user,
+                    device_id,
+                    {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+                )
             )
             self.fail("No error when replacing string key with dict")
         except errors.SynapseError:
             pass
 
         try:
-            yield self.handler.upload_keys_for_user(
-                local_user,
-                device_id,
-                {
-                    "one_time_keys": {
-                        "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
-                    }
-                },
+            yield defer.ensureDeferred(
+                self.handler.upload_keys_for_user(
+                    local_user,
+                    device_id,
+                    {
+                        "one_time_keys": {
+                            "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+                        }
+                    },
+                )
             )
             self.fail("No error when replacing dict key")
         except errors.SynapseError:
@@ -134,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
 
-        res = yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"one_time_keys": keys}
+        res = yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": keys}
+            )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
 
-        res2 = yield self.handler.claim_one_time_keys(
-            {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+        res2 = yield defer.ensureDeferred(
+            self.handler.claim_one_time_keys(
+                {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+            )
         )
         self.assertEqual(
             res2,
@@ -164,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys1)
+        )
 
         keys2 = {
             "master_key": {
@@ -176,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, keys2)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys2)
+        )
 
-        devices = yield self.handler.query_devices(
-            {"device_keys": {local_user: []}}, 0, local_user
+        devices = yield defer.ensureDeferred(
+            self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
@@ -216,13 +241,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
         )
-        yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys1)
+        )
 
         # upload two device keys, which will be signed later by the self-signing key
         device_key_1 = {
             "user_id": local_user,
             "device_id": "abc",
-            "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
             "keys": {
                 "ed25519:abc": "base64+ed25519+key",
                 "curve25519:abc": "base64+curve25519+key",
@@ -232,7 +262,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_key_2 = {
             "user_id": local_user,
             "device_id": "def",
-            "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
             "keys": {
                 "ed25519:def": "base64+ed25519+key",
                 "curve25519:def": "base64+curve25519+key",
@@ -240,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
         }
 
-        yield self.handler.upload_keys_for_user(
-            local_user, "abc", {"device_keys": device_key_1}
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, "abc", {"device_keys": device_key_1}
+            )
         )
-        yield self.handler.upload_keys_for_user(
-            local_user, "def", {"device_keys": device_key_2}
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, "def", {"device_keys": device_key_2}
+            )
         )
 
         # sign the first device key and upload it
         del device_key_1["signatures"]
         sign.sign_json(device_key_1, local_user, signing_key)
-        yield self.handler.upload_signatures_for_device_keys(
-            local_user, {local_user: {"abc": device_key_1}}
+        yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user, {local_user: {"abc": device_key_1}}
+            )
         )
 
         # sign the second device key and upload both device keys.  The server
@@ -259,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # signature for it
         del device_key_2["signatures"]
         sign.sign_json(device_key_2, local_user, signing_key)
-        yield self.handler.upload_signatures_for_device_keys(
-            local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+        yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
+            )
         )
 
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
-        devices = yield self.handler.query_devices(
-            {"device_keys": {local_user: []}}, 0, local_user
+        devices = yield defer.ensureDeferred(
+            self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
         del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -287,20 +328,26 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, keys1)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, keys1)
+        )
 
         res = None
         try:
-            yield self.hs.get_device_handler().check_device_registered(
-                user_id=local_user,
-                device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
-                initial_device_display_name="new display name",
+            yield defer.ensureDeferred(
+                self.hs.get_device_handler().check_device_registered(
+                    user_id=local_user,
+                    device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+                    initial_device_display_name="new display name",
+                )
             )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 400)
 
-        res = yield self.handler.query_local_devices({local_user: None})
+        res = yield defer.ensureDeferred(
+            self.handler.query_local_devices({local_user: None})
+        )
         self.assertDictEqual(res, {local_user: {}})
 
     @defer.inlineCallbacks
@@ -315,7 +362,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         device_key = {
             "user_id": local_user,
             "device_id": device_id,
-            "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
             "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey},
             "signatures": {local_user: {"ed25519:xyz": "something"}},
         }
@@ -323,8 +373,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
         )
 
-        yield self.handler.upload_keys_for_user(
-            local_user, device_id, {"device_keys": device_key}
+        yield defer.ensureDeferred(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"device_keys": device_key}
+            )
         )
 
         # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -364,7 +416,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "user_signing_key": usersigning_key,
             "self_signing_key": selfsigning_key,
         }
-        yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
+        )
 
         # set up another user with a master key.  This user will be signed by
         # the first user
@@ -376,76 +430,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "usage": ["master"],
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
         }
-        yield self.handler.upload_signing_keys_for_user(
-            other_user, {"master_key": other_master_key}
+        yield defer.ensureDeferred(
+            self.handler.upload_signing_keys_for_user(
+                other_user, {"master_key": other_master_key}
+            )
         )
 
         # test various signature failures (see below)
-        ret = yield self.handler.upload_signatures_for_device_keys(
-            local_user,
-            {
-                local_user: {
-                    # fails because the signature is invalid
-                    # should fail with INVALID_SIGNATURE
-                    device_id: {
-                        "user_id": local_user,
-                        "device_id": device_id,
-                        "algorithms": [
-                            "m.olm.curve25519-aes-sha2",
-                            "m.megolm.v1.aes-sha2",
-                        ],
-                        "keys": {
-                            "curve25519:xyz": "curve25519+key",
-                            # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
-                            "ed25519:xyz": device_pubkey,
+        ret = yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user,
+                {
+                    local_user: {
+                        # fails because the signature is invalid
+                        # should fail with INVALID_SIGNATURE
+                        device_id: {
+                            "user_id": local_user,
+                            "device_id": device_id,
+                            "algorithms": [
+                                "m.olm.curve25519-aes-sha2",
+                                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+                            ],
+                            "keys": {
+                                "curve25519:xyz": "curve25519+key",
+                                # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
+                                "ed25519:xyz": device_pubkey,
+                            },
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + selfsigning_pubkey: "something"
+                                }
+                            },
                         },
-                        "signatures": {
-                            local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+                        # fails because device is unknown
+                        # should fail with NOT_FOUND
+                        "unknown": {
+                            "user_id": local_user,
+                            "device_id": "unknown",
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + selfsigning_pubkey: "something"
+                                }
+                            },
                         },
-                    },
-                    # fails because device is unknown
-                    # should fail with NOT_FOUND
-                    "unknown": {
-                        "user_id": local_user,
-                        "device_id": "unknown",
-                        "signatures": {
-                            local_user: {"ed25519:" + selfsigning_pubkey: "something"}
+                        # fails because the signature is invalid
+                        # should fail with INVALID_SIGNATURE
+                        master_pubkey: {
+                            "user_id": local_user,
+                            "usage": ["master"],
+                            "keys": {"ed25519:" + master_pubkey: master_pubkey},
+                            "signatures": {
+                                local_user: {"ed25519:" + device_pubkey: "something"}
+                            },
                         },
                     },
-                    # fails because the signature is invalid
-                    # should fail with INVALID_SIGNATURE
-                    master_pubkey: {
-                        "user_id": local_user,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + master_pubkey: master_pubkey},
-                        "signatures": {
-                            local_user: {"ed25519:" + device_pubkey: "something"}
+                    other_user: {
+                        # fails because the device is not the user's master-signing key
+                        # should fail with NOT_FOUND
+                        "unknown": {
+                            "user_id": other_user,
+                            "device_id": "unknown",
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + usersigning_pubkey: "something"
+                                }
+                            },
                         },
-                    },
-                },
-                other_user: {
-                    # fails because the device is not the user's master-signing key
-                    # should fail with NOT_FOUND
-                    "unknown": {
-                        "user_id": other_user,
-                        "device_id": "unknown",
-                        "signatures": {
-                            local_user: {"ed25519:" + usersigning_pubkey: "something"}
-                        },
-                    },
-                    other_master_pubkey: {
-                        # fails because the key doesn't match what the server has
-                        # should fail with UNKNOWN
-                        "user_id": other_user,
-                        "usage": ["master"],
-                        "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
-                        "something": "random",
-                        "signatures": {
-                            local_user: {"ed25519:" + usersigning_pubkey: "something"}
+                        other_master_pubkey: {
+                            # fails because the key doesn't match what the server has
+                            # should fail with UNKNOWN
+                            "user_id": other_user,
+                            "usage": ["master"],
+                            "keys": {
+                                "ed25519:" + other_master_pubkey: other_master_pubkey
+                            },
+                            "something": "random",
+                            "signatures": {
+                                local_user: {
+                                    "ed25519:" + usersigning_pubkey: "something"
+                                }
+                            },
                         },
                     },
                 },
-            },
+            )
         )
 
         user_failures = ret["failures"][local_user]
@@ -470,19 +538,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
-        ret = yield self.handler.upload_signatures_for_device_keys(
-            local_user,
-            {
-                local_user: {device_id: device_key, master_pubkey: master_key},
-                other_user: {other_master_pubkey: other_master_key},
-            },
+        ret = yield defer.ensureDeferred(
+            self.handler.upload_signatures_for_device_keys(
+                local_user,
+                {
+                    local_user: {device_id: device_key, master_pubkey: master_key},
+                    other_user: {other_master_pubkey: other_master_key},
+                },
+            )
         )
 
         self.assertEqual(ret["failures"], {})
 
         # fetch the signed keys/devices and make sure that the signatures are there
-        ret = yield self.handler.query_devices(
-            {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+        ret = yield defer.ensureDeferred(
+            self.handler.query_devices(
+                {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+            )
         )
 
         self.assertEqual(
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 70f172eb02..3362050ce0 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.get_version_info(self.local_user)
+            yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.get_version_info(self.local_user, "bogus_version")
+            yield defer.ensureDeferred(
+                self.handler.get_version_info(self.local_user, "bogus_version")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -87,15 +89,21 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_create_version(self):
         """Check that we can create and then retrieve versions.
         """
-        res = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        res = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(res, "1")
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         version_etag = res["etag"]
+        self.assertIsInstance(version_etag, str)
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -108,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as a specific version
-        res = yield self.handler.get_version_info(self.local_user, "1")
+        res = yield defer.ensureDeferred(
+            self.handler.get_version_info(self.local_user, "1")
+        )
         self.assertEqual(res["etag"], version_etag)
         del res["etag"]
         self.assertDictEqual(
@@ -122,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # upload a new one...
-        res = yield self.handler.create_version(
-            self.local_user,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "second_version_auth_data",
-            },
+        res = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "second_version_auth_data",
+                },
+            )
         )
         self.assertEqual(res, "2")
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -148,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_update_version(self):
         """Check that we can update versions.
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        res = yield self.handler.update_version(
-            self.local_user,
-            version,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "revised_first_version_auth_data",
-                "version": version,
-            },
+        res = yield defer.ensureDeferred(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": version,
+                },
+            )
         )
         self.assertDictEqual(res, {})
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -184,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.update_version(
-                self.local_user,
-                "1",
-                {
-                    "algorithm": "m.megolm_backup.v1",
-                    "auth_data": "revised_first_version_auth_data",
-                    "version": "1",
-                },
+            yield defer.ensureDeferred(
+                self.handler.update_version(
+                    self.local_user,
+                    "1",
+                    {
+                        "algorithm": "m.megolm_backup.v1",
+                        "auth_data": "revised_first_version_auth_data",
+                        "version": "1",
+                    },
+                )
             )
         except errors.SynapseError as e:
             res = e.code
@@ -201,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_update_omitted_version(self):
         """Check that the update succeeds if the version is missing from the body
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        yield self.handler.update_version(
-            self.local_user,
-            version,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "revised_first_version_auth_data",
-            },
+        yield defer.ensureDeferred(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                },
+            )
         )
 
         # check we can retrieve it as the current version
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         del res["etag"]  # etag is opaque, so don't test its contents
         self.assertDictEqual(
             res,
@@ -233,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_update_bad_version(self):
         """Check that we get a 400 if the version in the body doesn't match
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
         res = None
         try:
-            yield self.handler.update_version(
-                self.local_user,
-                version,
-                {
-                    "algorithm": "m.megolm_backup.v1",
-                    "auth_data": "revised_first_version_auth_data",
-                    "version": "incorrect",
-                },
+            yield defer.ensureDeferred(
+                self.handler.update_version(
+                    self.local_user,
+                    version,
+                    {
+                        "algorithm": "m.megolm_backup.v1",
+                        "auth_data": "revised_first_version_auth_data",
+                        "version": "incorrect",
+                    },
+                )
             )
         except errors.SynapseError as e:
             res = e.code
@@ -260,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.delete_version(self.local_user, "1")
+            yield defer.ensureDeferred(
+                self.handler.delete_version(self.local_user, "1")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -271,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.delete_version(self.local_user)
+            yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -280,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_delete_version(self):
         """Check that we can create and then delete versions.
         """
-        res = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        res = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(res, "1")
 
         # check we can delete it
-        yield self.handler.delete_version(self.local_user, "1")
+        yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
 
         # check that it's gone
         res = None
         try:
-            yield self.handler.get_version_info(self.local_user, "1")
+            yield defer.ensureDeferred(
+                self.handler.get_version_info(self.local_user, "1")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -303,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.get_room_keys(self.local_user, "bogus_version")
+            yield defer.ensureDeferred(
+                self.handler.get_room_keys(self.local_user, "bogus_version")
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 404)
@@ -312,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_get_missing_room_keys(self):
         """Check we get an empty response from an empty backup
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertDictEqual(res, {"rooms": {}})
 
     # TODO: test the locking semantics when uploading room_keys,
@@ -330,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """
         res = None
         try:
-            yield self.handler.upload_room_keys(
-                self.local_user, "no_version", room_keys
+            yield defer.ensureDeferred(
+                self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
             )
         except errors.SynapseError as e:
             res = e.code
@@ -342,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         """Check that we get a 404 on uploading keys when an nonexistent version
         is specified
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
         res = None
         try:
-            yield self.handler.upload_room_keys(
-                self.local_user, "bogus_version", room_keys
+            yield defer.ensureDeferred(
+                self.handler.upload_room_keys(
+                    self.local_user, "bogus_version", room_keys
+                )
             )
         except errors.SynapseError as e:
             res = e.code
@@ -361,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_upload_room_keys_wrong_version(self):
         """Check that we get a 403 on uploading keys for an old version
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        version = yield self.handler.create_version(
-            self.local_user,
-            {
-                "algorithm": "m.megolm_backup.v1",
-                "auth_data": "second_version_auth_data",
-            },
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "second_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "2")
 
         res = None
         try:
-            yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
+            yield defer.ensureDeferred(
+                self.handler.upload_room_keys(self.local_user, "1", room_keys)
+            )
         except errors.SynapseError as e:
             res = e.code
         self.assertEqual(res, 403)
@@ -387,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_upload_room_keys_insert(self):
         """Check that we can insert and retrieve keys for a session
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given room
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org"
+            )
         )
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given session_id
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, room_keys)
 
@@ -414,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_upload_room_keys_merge(self):
         """Check that we can upload a new room_key for an existing session and
         have it correctly merged"""
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
 
         # get the etag to compare to future versions
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         backup_etag = res["etag"]
         self.assertEqual(res["count"], 1)
 
@@ -433,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # test that increasing the message_index doesn't replace the existing session
         new_room_key["first_message_index"] = 2
         new_room_key["session_data"] = "new"
-        yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
         )
 
         # the etag should be the same since the session did not change
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # test that marking the session as verified however /does/ replace it
         new_room_key["is_verified"] = True
-        yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should NOT be equal now, since the key changed
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         self.assertNotEqual(res["etag"], backup_etag)
         backup_etag = res["etag"]
 
@@ -463,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # with a lower forwarding count
         new_room_key["forwarded_count"] = 2
         new_room_key["session_data"] = "other"
-        yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, new_room_keys)
+        )
 
-        res = yield self.handler.get_room_keys(self.local_user, version)
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(self.local_user, version)
+        )
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should be the same since the session did not change
-        res = yield self.handler.get_version_info(self.local_user)
+        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # TODO: check edge cases as well as the common variations here
@@ -480,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
     def test_delete_room_keys(self):
         """Check that we can insert and delete keys for a session
         """
-        version = yield self.handler.create_version(
-            self.local_user,
-            {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
+        version = yield defer.ensureDeferred(
+            self.handler.create_version(
+                self.local_user,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "first_version_auth_data",
+                },
+            )
         )
         self.assertEqual(version, "1")
 
         # check for bulk-delete
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
-        yield self.handler.delete_room_keys(self.local_user, version)
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
+        yield defer.ensureDeferred(
+            self.handler.delete_room_keys(self.local_user, version)
+        )
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per room
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
-        yield self.handler.delete_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org"
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
+        yield defer.ensureDeferred(
+            self.handler.delete_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org"
+            )
         )
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per session
-        yield self.handler.upload_room_keys(self.local_user, version, room_keys)
-        yield self.handler.delete_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        yield defer.ensureDeferred(
+            self.handler.upload_room_keys(self.local_user, version, room_keys)
+        )
+        yield defer.ensureDeferred(
+            self.handler.delete_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
-        res = yield self.handler.get_room_keys(
-            self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+        res = yield defer.ensureDeferred(
+            self.handler.get_room_keys(
+                self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
+            )
         )
         self.assertDictEqual(res, {"rooms": {}})
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1bb25ab684..89ec5fcb31 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -75,7 +75,17 @@ COMMON_CONFIG = {
 COOKIE_NAME = b"oidc_session"
 COOKIE_PATH = "/_synapse/oidc"
 
-MockedMappingProvider = Mock(OidcMappingProvider)
+
+class TestMappingProvider(OidcMappingProvider):
+    @staticmethod
+    def parse_config(config):
+        return
+
+    def get_remote_user_id(self, userinfo):
+        return userinfo["sub"]
+
+    async def map_user_attributes(self, userinfo, token):
+        return {"localpart": userinfo["username"], "display_name": None}
 
 
 def simple_async_mock(return_value=None, raises=None):
@@ -123,7 +133,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         oidc_config["issuer"] = ISSUER
         oidc_config["scopes"] = SCOPES
         oidc_config["user_mapping_provider"] = {
-            "module": __name__ + ".MockedMappingProvider"
+            "module": __name__ + ".TestMappingProvider",
         }
         config["oidc_config"] = oidc_config
 
@@ -374,12 +384,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
         self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(spec=["args", "getCookie", "addCookie"])
+        request = Mock(
+            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+        )
 
         code = "code"
         state = "state"
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
+        user_agent = "Browser"
+        ip_address = "10.0.0.1"
         session = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
@@ -392,6 +406,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args[b"code"] = [code.encode("utf-8")]
         request.args[b"state"] = [state.encode("utf-8")]
 
+        request.requestHeaders = Mock(spec=["getRawHeaders"])
+        request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
+        request.getClientIP.return_value = ip_address
+
         yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
 
         self.handler._auth_handler.complete_sso_login.assert_called_once_with(
@@ -399,7 +417,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_not_called()
         self.handler._render_error.assert_not_called()
 
@@ -431,7 +451,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
+        self.handler._map_userinfo_to_user.assert_called_once_with(
+            userinfo, token, user_agent, ip_address
+        )
         self.handler._fetch_userinfo.assert_called_once_with(token)
         self.handler._render_error.assert_not_called()
 
@@ -568,3 +590,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with self.assertRaises(OidcError) as exc:
             yield defer.ensureDeferred(self.handler._exchange_code(code))
         self.assertEqual(exc.exception.error, "some_error")
+
+    def test_map_userinfo_to_user(self):
+        """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        userinfo = {
+            "sub": "test_user",
+            "username": "test_user",
+        }
+        # The token doesn't matter with the default user mapping provider.
+        token = {}
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
+        # Some providers return an integer ID.
+        userinfo = {
+            "sub": 1234,
+            "username": "test_user_2",
+        }
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user_2:test")
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 05ea40a7de..306dcfe944 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -19,6 +19,7 @@ from mock import Mock, call
 from signedjson.key import generate_signing_key
 
 from synapse.api.constants import EventTypes, Membership, PresenceState
+from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events.builder import EventBuilder
 from synapse.handlers.presence import (
@@ -32,7 +33,6 @@ from synapse.handlers.presence import (
     handle_update,
 )
 from synapse.rest.client.v1 import room
-from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, get_domain_from_id
 
 from tests import unittest
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..8e95e53d9e 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -24,10 +24,11 @@ from synapse.handlers.profile import MasterProfileHandler
 from synapse.types import UserID
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
-class ProfileHandlers(object):
+class ProfileHandlers:
     def __init__(self, hs):
         self.profile_handler = MasterProfileHandler(hs)
 
@@ -63,16 +64,20 @@ class ProfileTestCase(unittest.TestCase):
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield self.store.create_profile(self.frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
 
         self.handler = hs.get_profile_handler()
         self.hs = hs
 
     @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        )
 
-        displayname = yield self.handler.get_displayname(self.frank)
+        displayname = yield defer.ensureDeferred(
+            self.handler.get_displayname(self.frank)
+        )
 
         self.assertEquals("Frank", displayname)
 
@@ -101,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
     @defer.inlineCallbacks
@@ -109,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.frank.localpart, "Frank")
+        )
 
         self.assertEquals(
-            (yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            ),
+            "Frank",
         )
 
         # Setting displayname a second time is forbidden
@@ -136,11 +153,13 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_other_name(self):
-        self.mock_federation.make_query.return_value = defer.succeed(
+        self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
-        displayname = yield self.handler.get_displayname(self.alice)
+        displayname = yield defer.ensureDeferred(
+            self.handler.get_displayname(self.alice)
+        )
 
         self.assertEquals(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
@@ -152,22 +171,27 @@ class ProfileTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield self.store.create_profile("caroline")
-        yield self.store.set_profile_displayname("caroline", "Caroline")
+        yield defer.ensureDeferred(self.store.create_profile("caroline"))
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname("caroline", "Caroline")
+        )
 
-        response = yield self.query_handlers["profile"](
-            {"user_id": "@caroline:test", "field": "displayname"}
+        response = yield defer.ensureDeferred(
+            self.query_handlers["profile"](
+                {"user_id": "@caroline:test", "field": "displayname"}
+            )
         )
 
         self.assertEquals({"displayname": "Caroline"}, response)
 
     @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png"
+            )
         )
-
-        avatar_url = yield self.handler.get_avatar_url(self.frank)
+        avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
 
         self.assertEquals("http://my.server/me.png", avatar_url)
 
@@ -182,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/pic.gif",
         )
 
@@ -196,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/me.png",
         )
 
@@ -205,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield self.store.set_profile_avatar_url(
-            self.frank.localpart, "http://my.server/me.png"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.frank.localpart, "http://my.server/me.png"
+            )
         )
 
         self.assertEquals(
-            (yield self.store.get_profile_avatar_url(self.frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
             "http://my.server/me.png",
         )
 
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ca32f993a3..eddf5e2498 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -15,17 +15,21 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
+from synapse.api.auth import Auth
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, ResourceLimitError, SynapseError
 from synapse.handlers.register import RegistrationHandler
+from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, UserID, create_requester
 
+from tests.test_utils import make_awaitable
+from tests.unittest import override_config
+from tests.utils import mock_getRawHeaders
+
 from .. import unittest
 
 
-class RegistrationHandlers(object):
+class RegistrationHandlers:
     def __init__(self, hs):
         self.registration_handler = RegistrationHandler(hs)
 
@@ -96,7 +100,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_get_or_create_user_mau_not_blocked(self):
         self.hs.config.limit_usage_by_mau = True
         self.store.count_monthly_users = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value - 1)
+            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value - 1)
         )
         # Ensure does not throw exception
         self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@@ -104,7 +108,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_get_or_create_user_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
         self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.lots_of_users)
+            side_effect=lambda: make_awaitable(self.lots_of_users)
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -112,7 +116,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         )
 
         self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value)
+            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
         )
         self.get_failure(
             self.get_or_create_user(self.requester, "b", "display_name"),
@@ -122,14 +126,14 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     def test_register_mau_blocked(self):
         self.hs.config.limit_usage_by_mau = True
         self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.lots_of_users)
+            side_effect=lambda: make_awaitable(self.lots_of_users)
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
         )
 
         self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value)
+            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
         )
         self.get_failure(
             self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -145,9 +149,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
 
+    @override_config({"auto_join_rooms": ["#room:test"]})
     def test_auto_create_auto_join_rooms(self):
         room_alias_str = "#room:test"
-        self.hs.config.auto_join_rooms = [room_alias_str]
         user_id = self.get_success(self.handler.register_user(localpart="jeff"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
@@ -185,7 +189,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.is_real_user = Mock(return_value=defer.succeed(False))
+        self.store.is_real_user = Mock(return_value=make_awaitable(False))
         user_id = self.get_success(self.handler.register_user(localpart="support"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
@@ -193,12 +197,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias = RoomAlias.from_string(room_alias_str)
         self.get_failure(directory_handler.get_association(room_alias), SynapseError)
 
+    @override_config({"auto_join_rooms": ["#room:test"]})
     def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
         room_alias_str = "#room:test"
-        self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=defer.succeed(1))
-        self.store.is_real_user = Mock(return_value=defer.succeed(True))
+        self.store.count_real_users = Mock(return_value=make_awaitable(1))
+        self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         directory_handler = self.hs.get_handlers().directory_handler
@@ -212,12 +216,218 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
         room_alias_str = "#room:test"
         self.hs.config.auto_join_rooms = [room_alias_str]
 
-        self.store.count_real_users = Mock(return_value=defer.succeed(2))
-        self.store.is_real_user = Mock(return_value=defer.succeed(True))
+        self.store.count_real_users = Mock(return_value=make_awaitable(2))
+        self.store.is_real_user = Mock(return_value=make_awaitable(True))
         user_id = self.get_success(self.handler.register_user(localpart="real"))
         rooms = self.get_success(self.store.get_rooms_for_user(user_id))
         self.assertEqual(len(rooms), 0)
 
+    @override_config(
+        {
+            "auto_join_rooms": ["#room:test"],
+            "autocreate_auto_join_rooms_federated": False,
+        }
+    )
+    def test_auto_create_auto_join_rooms_federated(self):
+        """
+        Auto-created rooms that are private require an invite to go to the user
+        (instead of directly joining it).
+        """
+        room_alias_str = "#room:test"
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+        # Ensure the room was created.
+        directory_handler = self.hs.get_handlers().directory_handler
+        room_alias = RoomAlias.from_string(room_alias_str)
+        room_id = self.get_success(directory_handler.get_association(room_alias))
+
+        # Ensure the room is properly not federated.
+        room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        self.assertFalse(room["federatable"])
+        self.assertFalse(room["public"])
+        self.assertEqual(room["join_rules"], "public")
+        self.assertIsNone(room["guest_access"])
+
+        # The user should be in the room.
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+    @override_config(
+        {"auto_join_rooms": ["#room:test"], "auto_join_mxid_localpart": "support"}
+    )
+    def test_auto_join_mxid_localpart(self):
+        """
+        Ensure the user still needs up in the room created by a different user.
+        """
+        # Ensure the support user exists.
+        inviter = "@support:test"
+
+        room_alias_str = "#room:test"
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+        # Ensure the room was created.
+        directory_handler = self.hs.get_handlers().directory_handler
+        room_alias = RoomAlias.from_string(room_alias_str)
+        room_id = self.get_success(directory_handler.get_association(room_alias))
+
+        # Ensure the room is properly a public room.
+        room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        self.assertEqual(room["join_rules"], "public")
+
+        # Both users should be in the room.
+        rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+        self.assertIn(room_id["room_id"], rooms)
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+        # Register a second user, which should also end up in the room.
+        user_id = self.get_success(self.handler.register_user(localpart="bob"))
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+    @override_config(
+        {
+            "auto_join_rooms": ["#room:test"],
+            "autocreate_auto_join_room_preset": "private_chat",
+            "auto_join_mxid_localpart": "support",
+        }
+    )
+    def test_auto_create_auto_join_room_preset(self):
+        """
+        Auto-created rooms that are private require an invite to go to the user
+        (instead of directly joining it).
+        """
+        # Ensure the support user exists.
+        inviter = "@support:test"
+
+        room_alias_str = "#room:test"
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+        # Ensure the room was created.
+        directory_handler = self.hs.get_handlers().directory_handler
+        room_alias = RoomAlias.from_string(room_alias_str)
+        room_id = self.get_success(directory_handler.get_association(room_alias))
+
+        # Ensure the room is properly a private room.
+        room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        self.assertFalse(room["public"])
+        self.assertEqual(room["join_rules"], "invite")
+        self.assertEqual(room["guest_access"], "can_join")
+
+        # Both users should be in the room.
+        rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+        self.assertIn(room_id["room_id"], rooms)
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+        # Register a second user, which should also end up in the room.
+        user_id = self.get_success(self.handler.register_user(localpart="bob"))
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+    @override_config(
+        {
+            "auto_join_rooms": ["#room:test"],
+            "autocreate_auto_join_room_preset": "private_chat",
+            "auto_join_mxid_localpart": "support",
+        }
+    )
+    def test_auto_create_auto_join_room_preset_guest(self):
+        """
+        Auto-created rooms that are private require an invite to go to the user
+        (instead of directly joining it).
+
+        This should also work for guests.
+        """
+        inviter = "@support:test"
+
+        room_alias_str = "#room:test"
+        user_id = self.get_success(
+            self.handler.register_user(localpart="jeff", make_guest=True)
+        )
+
+        # Ensure the room was created.
+        directory_handler = self.hs.get_handlers().directory_handler
+        room_alias = RoomAlias.from_string(room_alias_str)
+        room_id = self.get_success(directory_handler.get_association(room_alias))
+
+        # Ensure the room is properly a private room.
+        room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+        self.assertFalse(room["public"])
+        self.assertEqual(room["join_rules"], "invite")
+        self.assertEqual(room["guest_access"], "can_join")
+
+        # Both users should be in the room.
+        rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+        self.assertIn(room_id["room_id"], rooms)
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+    @override_config(
+        {
+            "auto_join_rooms": ["#room:test"],
+            "autocreate_auto_join_room_preset": "private_chat",
+            "auto_join_mxid_localpart": "support",
+        }
+    )
+    def test_auto_create_auto_join_room_preset_invalid_permissions(self):
+        """
+        Auto-created rooms that are private require an invite, check that
+        registration doesn't completely break if the inviter doesn't have proper
+        permissions.
+        """
+        inviter = "@support:test"
+
+        # Register an initial user to create the room and such (essentially this
+        # is a subset of test_auto_create_auto_join_room_preset).
+        room_alias_str = "#room:test"
+        user_id = self.get_success(self.handler.register_user(localpart="jeff"))
+
+        # Ensure the room was created.
+        directory_handler = self.hs.get_handlers().directory_handler
+        room_alias = RoomAlias.from_string(room_alias_str)
+        room_id = self.get_success(directory_handler.get_association(room_alias))
+
+        # Ensure the room exists.
+        self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
+
+        # Both users should be in the room.
+        rooms = self.get_success(self.store.get_rooms_for_user(inviter))
+        self.assertIn(room_id["room_id"], rooms)
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertIn(room_id["room_id"], rooms)
+
+        # Lower the permissions of the inviter.
+        event_creation_handler = self.hs.get_event_creation_handler()
+        requester = create_requester(inviter)
+        event, context = self.get_success(
+            event_creation_handler.create_event(
+                requester,
+                {
+                    "type": "m.room.power_levels",
+                    "state_key": "",
+                    "room_id": room_id["room_id"],
+                    "content": {"invite": 100, "users": {inviter: 0}},
+                    "sender": inviter,
+                },
+            )
+        )
+        self.get_success(
+            event_creation_handler.send_nonmember_event(requester, event, context)
+        )
+
+        # Register a second user, which won't be be in the room (or even have an invite)
+        # since the inviter no longer has the proper permissions.
+        user_id = self.get_success(self.handler.register_user(localpart="bob"))
+
+        # This user should not be in any rooms.
+        rooms = self.get_success(self.store.get_rooms_for_user(user_id))
+        invited_rooms = self.get_success(
+            self.store.get_invited_rooms_for_local_user(user_id)
+        )
+        self.assertEqual(rooms, set())
+        self.assertEqual(invited_rooms, [])
+
     def test_auto_create_auto_join_where_no_consent(self):
         """Test to ensure that the first user is not auto-joined to a room if
         they have not given general consent.
@@ -266,6 +476,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             self.handler.register_user(localpart=invalid_user_id), SynapseError
         )
 
+    def test_spam_checker_deny(self):
+        """A spam checker can deny registration, which results in an error."""
+
+        class DenyAll:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info
+            ):
+                return RegistrationBehaviour.DENY
+
+        # Configure a spam checker that denies all users.
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [DenyAll()]
+
+        self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
+
+    def test_spam_checker_shadow_ban(self):
+        """A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
+
+        class BanAll:
+            def check_registration_for_spam(
+                self, email_threepid, username, request_info
+            ):
+                return RegistrationBehaviour.SHADOW_BAN
+
+        # Configure a spam checker that denies all users.
+        spam_checker = self.hs.get_spam_checker()
+        spam_checker.spam_checkers = [BanAll()]
+
+        user_id = self.get_success(self.handler.register_user(localpart="user"))
+
+        # Get an access token.
+        token = self.macaroon_generator.generate_access_token(user_id)
+        self.get_success(
+            self.store.add_access_token_to_user(
+                user_id=user_id, token=token, device_id=None, valid_until_ms=None
+            )
+        )
+
+        # Ensure the user was marked as shadow-banned.
+        request = Mock(args={})
+        request.args[b"access_token"] = [token.encode("ascii")]
+        request.requestHeaders.getRawHeaders = mock_getRawHeaders()
+        auth = Auth(self.hs)
+        requester = self.get_success(auth.get_user_by_req(request))
+
+        self.assertTrue(requester.shadow_banned)
+
     async def get_or_create_user(
         self, requester, localpart, displayname, password_hash=None
     ):
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..a609f148c0 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -15,7 +15,7 @@
 
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
-from synapse.storage.data_stores.main import stats
+from synapse.storage.databases.main import stats
 
 from tests import unittest
 
@@ -42,36 +42,36 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
-                    "update_name": "populate_stats_process_rooms",
+                    "update_name": "populate_stats_process_rooms_2",
                     "progress_json": "{}",
                     "depends_on": "populate_stats_prepare",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
+                    "depends_on": "populate_stats_process_rooms_2",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
             )
         )
 
-    def get_all_room_state(self):
-        return self.store.db.simple_select_list(
+    async def get_all_room_state(self):
+        return await self.store.db_pool.simple_select_list(
             "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
         )
 
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
 
         return self.get_success(
-            self.store.db.simple_select_one(
+            self.store.db_pool.simple_select_one(
                 table + "_historical",
                 {id_col: stat_id, end_ts: end_ts},
                 cols,
@@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
     def test_initial_room(self):
@@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         r = self.get_success(self.get_all_room_state())
@@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # the position that the deltas should begin at, once they take over.
         self.hs.config.stats_enabled = True
         self.handler.stats_enabled = True
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
         self.get_success(
-            self.store.db.simple_update_one(
+            self.store.db_pool.simple_update_one(
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": 0},
@@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "populate_stats_prepare", "progress_json": "{}"},
             )
         )
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Now, before the table is actually ingested, add some more events.
@@ -217,28 +217,31 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # Now do the initial ingestion.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
-                {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+                {
+                    "update_name": "populate_stats_process_rooms_2",
+                    "progress_json": "{}",
+                },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
+                    "depends_on": "populate_stats_process_rooms_2",
                 },
             )
         )
 
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         self.reactor.advance(86401)
@@ -253,7 +256,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         # self.handler.notify_new_event()
 
         # We need to let the delta processor advance…
-        self.pump(10 * 60)
+        self.reactor.advance(10 * 60)
 
         # Get the slices! There should be two -- day 1, and day 2.
         r = self.get_success(self.store.get_statistics_for_subject("room", room_1, 0))
@@ -346,6 +349,37 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
 
+    def test_updating_profile_information_does_not_increase_joined_members_count(self):
+        """
+        Check that the joined_members count does not increase when a user changes their
+        profile information (which is done by sending another join membership event into
+        the room.
+        """
+        self._perform_background_initial_update()
+
+        # Create a user and room
+        u1 = self.register_user("u1", "pass")
+        u1token = self.login("u1", "pass")
+        r1 = self.helper.create_room_as(u1, tok=u1token)
+
+        # Get the current room stats
+        r1stats_ante = self._get_current_stats("room", r1)
+
+        # Send a profile update into the room
+        new_profile = {"displayname": "bob"}
+        self.helper.change_membership(
+            r1, u1, u1, "join", extra_data=new_profile, tok=u1token
+        )
+
+        # Get the new room stats
+        r1stats_post = self._get_current_stats("room", r1)
+
+        # Ensure that the user count did not changed
+        self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
+        self.assertEqual(
+            r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
+        )
+
     def test_send_state_event_nonoverwriting(self):
         """
         When we send a non-overwriting state event, it increments total_events AND current_state_events
@@ -669,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # preparation stage of the initial background update
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         self.get_success(
-            self.store.db.simple_delete(
+            self.store.db_pool.simple_delete(
                 "room_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
         self.get_success(
-            self.store.db.simple_delete(
+            self.store.db_pool.simple_delete(
                 "user_stats_current", {"1": 1}, "test_delete_stats"
             )
         )
@@ -689,29 +723,29 @@ class StatsRoomTests(unittest.HomeserverTestCase):
 
         # now do the background updates
 
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
-                    "update_name": "populate_stats_process_rooms",
+                    "update_name": "populate_stats_process_rooms_2",
                     "progress_json": "{}",
                     "depends_on": "populate_stats_prepare",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_process_users",
                     "progress_json": "{}",
-                    "depends_on": "populate_stats_process_rooms",
+                    "depends_on": "populate_stats_process_rooms_2",
                 },
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_stats_cleanup",
@@ -722,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         r1stats_complete = self._get_current_stats("room", r1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 2fa8d4739b..7bf15c4ba9 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -21,9 +21,10 @@ from mock import ANY, Mock, call
 from twisted.internet import defer
 
 from synapse.api.errors import AuthError
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import register_federation_servlets
 
@@ -115,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             retry_timings_res
         )
 
-        self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+        self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
             (0, [])
         )
 
@@ -126,9 +127,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.room_members = []
 
-        def check_user_in_room(room_id, user_id):
+        async def check_user_in_room(room_id, user_id):
             if user_id not in [u.to_string() for u in self.room_members]:
                 raise AuthError(401, "User is not in the room")
+            return None
 
         hs.get_auth().check_user_in_room = check_user_in_room
 
@@ -137,24 +139,26 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
 
-        def get_current_users_in_room(room_id):
-            return {str(u) for u in self.room_members}
+        def get_users_in_room(room_id):
+            return defer.succeed({str(u) for u in self.room_members})
 
-        hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
+        self.datastore.get_users_in_room = get_users_in_room
 
-        self.datastore.get_user_directory_stream_pos.return_value = (
+        self.datastore.get_user_directory_stream_pos.side_effect = (
             # we deliberately return a non-None stream pos to avoid doing an initial_spam
-            defer.succeed(1)
+            lambda: make_awaitable(1)
         )
 
         self.datastore.get_current_state_deltas.return_value = (0, None)
 
         self.datastore.get_to_device_stream_token = lambda: 0
-        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
             ([], 0)
         )
-        self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
-        self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed(
+        self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
+            None
+        )
+        self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
             None
         )
 
@@ -163,9 +167,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        self.successResultOf(
+        self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=20000,
             )
         )
 
@@ -190,9 +197,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
     def test_started_typing_remote_send(self):
         self.room_members = [U_APPLE, U_ONION]
 
-        self.successResultOf(
+        self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=20000,
             )
         )
 
@@ -265,9 +275,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        self.successResultOf(
+        self.get_success(
             self.handler.stopped_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
             )
         )
 
@@ -305,9 +317,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        self.successResultOf(
+        self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=10000,
             )
         )
 
@@ -344,9 +359,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         # SYN-230 - see if we can still set after timeout
 
-        self.successResultOf(
+        self.get_success(
             self.handler.started_typing(
-                target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000
+                target_user=U_APPLE,
+                requester=create_requester(U_APPLE),
+                room_id=ROOM_ID,
+                timeout=10000,
             )
         )
 
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c15bce5bef..87be94111f 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -17,12 +17,13 @@ from mock import Mock
 from twisted.internet import defer
 
 import synapse.rest.admin
-from synapse.api.constants import UserTypes
+from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import user_directory
 from synapse.storage.roommember import ProfileInfo
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class UserDirectoryTestCase(unittest.HomeserverTestCase):
@@ -147,9 +148,97 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user3", 10))
         self.assertEqual(len(s["results"]), 0)
 
+    @override_config({"encryption_enabled_by_default_for_room_type": "all"})
+    def test_encrypted_by_default_config_option_all(self):
+        """Tests that invite-only and non-invite-only rooms have encryption enabled by
+        default when the config option encryption_enabled_by_default_for_room_type is "all".
+        """
+        # Create a user
+        user = self.register_user("user", "pass")
+        user_token = self.login(user, "pass")
+
+        # Create an invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+        # Check that the room has an encryption state event
+        event_content = self.helper.get_state(
+            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+        )
+        self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+        # Create a non invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+        # Check that the room has an encryption state event
+        event_content = self.helper.get_state(
+            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+        )
+        self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+    @override_config({"encryption_enabled_by_default_for_room_type": "invite"})
+    def test_encrypted_by_default_config_option_invite(self):
+        """Tests that only new, invite-only rooms have encryption enabled by default when
+        the config option encryption_enabled_by_default_for_room_type is "invite".
+        """
+        # Create a user
+        user = self.register_user("user", "pass")
+        user_token = self.login(user, "pass")
+
+        # Create an invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+        # Check that the room has an encryption state event
+        event_content = self.helper.get_state(
+            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+        )
+        self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
+
+        # Create a non invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+        # Check that the room does not have an encryption state event
+        self.helper.get_state(
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
+            expect_code=404,
+        )
+
+    @override_config({"encryption_enabled_by_default_for_room_type": "off"})
+    def test_encrypted_by_default_config_option_off(self):
+        """Tests that neither new invite-only nor non-invite-only rooms have encryption
+        enabled by default when the config option
+        encryption_enabled_by_default_for_room_type is "off".
+        """
+        # Create a user
+        user = self.register_user("user", "pass")
+        user_token = self.login(user, "pass")
+
+        # Create an invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=False, tok=user_token)
+
+        # Check that the room does not have an encryption state event
+        self.helper.get_state(
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
+            expect_code=404,
+        )
+
+        # Create a non invite-only room as that user
+        room_id = self.helper.create_room_as(user, is_public=True, tok=user_token)
+
+        # Check that the room does not have an encryption state event
+        self.helper.get_state(
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
+            expect_code=404,
+        )
+
     def test_spam_checker(self):
         """
-        A user which fails to the spam checks will not appear in search results.
+        A user which fails the spam checks will not appear in search results.
         """
         u1 = self.register_user("user1", "pass")
         u1_token = self.login(u1, "pass")
@@ -180,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         # Configure a spam checker that does not filter any users.
         spam_checker = self.hs.get_spam_checker()
 
-        class AllowAll(object):
+        class AllowAll:
             def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
@@ -193,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        class BlockAll(object):
+        class BlockAll:
             def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True
@@ -250,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_in_public_rooms(self):
         r = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 "users_in_public_rooms", None, ("user_id", "room_id")
             )
         )
@@ -261,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
     def get_users_who_share_private_rooms(self):
         return self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 "users_who_share_private_rooms",
                 None,
                 ["user_id", "other_user_id", "room_id"],
@@ -273,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         Add the background updates we need to run.
         """
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_createtables",
@@ -285,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_rooms",
@@ -295,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_process_users",
@@ -305,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {
                     "update_name": "populate_user_directory_cleanup",
@@ -348,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         shares_private = self.get_users_who_share_private_rooms()
@@ -387,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self._add_background_updates()
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         shares_private = self.get_users_who_share_private_rooms()
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2096ba3c91..5d41443293 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -133,7 +133,7 @@ def create_test_cert_file(sanlist):
 
 
 @implementer(IOpenSSLServerConnectionCreator)
-class TestServerTLSConnectionFactory(object):
+class TestServerTLSConnectionFactory:
     """An SSL connection creator which returns connections which present a certificate
     signed by our test CA."""
 
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 562397cdda..8b5ad4574f 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -67,6 +67,14 @@ def get_connection_factory():
     return test_server_connection_factory
 
 
+# Once Async Mocks or lambdas are supported this can go away.
+def generate_resolve_service(result):
+    async def resolve_service(_):
+        return result
+
+    return resolve_service
+
+
 class MatrixFederationAgentTests(unittest.TestCase):
     def setUp(self):
         self.reactor = ThreadedMemoryReactorClock()
@@ -86,6 +94,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.well_known_resolver = WellKnownResolver(
             self.reactor,
             Agent(self.reactor, contextFactory=self.tls_factory),
+            b"test-agent",
             well_known_cache=self.well_known_cache,
             had_well_known_cache=self.had_well_known_cache,
         )
@@ -93,6 +102,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.agent = MatrixFederationAgent(
             reactor=self.reactor,
             tls_client_options_factory=self.tls_factory,
+            user_agent="test-agent",  # Note that this is unused since _well_known_resolver is provided.
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=self.well_known_resolver,
         )
@@ -186,6 +196,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # check the .well-known request and send a response
         self.assertEqual(len(well_known_server.requests), 1)
         request = well_known_server.requests[0]
+        self.assertEqual(
+            request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+        )
         self._send_well_known_response(request, content, headers=response_headers)
         return well_known_server
 
@@ -231,6 +244,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(
             request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"]
         )
+        self.assertEqual(
+            request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"]
+        )
         content = request.content.read()
         self.assertEqual(content, b"")
 
@@ -365,7 +381,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         Test the behaviour when the certificate on the server doesn't match the hostname
         """
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv1"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv1/foo/bar")
@@ -448,7 +464,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         Test the behaviour when the server name has no port, no SRV, and no well-known
         """
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -502,7 +518,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the .well-known delegates elsewhere
         """
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -564,7 +580,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """Test the behaviour when the server name has no port and no SRV record, but
         the .well-known has a 300 redirect
         """
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.reactor.lookups["target-server"] = "1::f"
 
@@ -653,7 +669,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         Test the behaviour when the server name has an *invalid* well-known (and no SRV)
         """
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -709,7 +725,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # the config left to the default, which will not trust it (since the
         # presented cert is signed by a test CA)
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
         config = default_config("test", parse=True)
@@ -719,10 +735,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
         agent = MatrixFederationAgent(
             reactor=self.reactor,
             tls_client_options_factory=tls_factory,
+            user_agent=b"test-agent",  # This is unused since _well_known_resolver is passed below.
             _srv_resolver=self.mock_resolver,
             _well_known_resolver=WellKnownResolver(
                 self.reactor,
                 Agent(self.reactor, contextFactory=tls_factory),
+                b"test-agent",
                 well_known_cache=self.well_known_cache,
                 had_well_known_cache=self.had_well_known_cache,
             ),
@@ -754,9 +772,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         """
         Test the behaviour when there is a single SRV record
         """
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"srvtarget", port=8443)
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [Server(host=b"srvtarget", port=8443)]
+        )
         self.reactor.lookups["srvtarget"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -809,9 +827,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.assertEqual(host, "1.2.3.4")
         self.assertEqual(port, 443)
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"srvtarget", port=8443)
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [Server(host=b"srvtarget", port=8443)]
+        )
 
         self._handle_well_known_connection(
             client_factory,
@@ -851,7 +869,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_idna_servername(self):
         """test the behaviour when the server name has idna chars in"""
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: []
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
 
         # the resolver is always called with the IDNA hostname as a native string.
         self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@@ -912,9 +930,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_idna_srv_target(self):
         """test the behaviour when the target of a SRV record has idna chars"""
 
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"xn--trget-3qa.com", port=8443)  # târget.com
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [Server(host=b"xn--trget-3qa.com", port=8443)]  # târget.com
+        )
         self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
@@ -954,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_well_known_cache(self):
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
 
         # there should be an attempt to connect on port 443 for the .well-known
         clients = self.reactor.tcpClients
@@ -977,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         well_known_server.loseConnection()
 
         # repeat the request: it should hit the cache
-        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
         r = self.successResultOf(fetch_d)
         self.assertEqual(r.delegated_server, b"target-server")
 
@@ -985,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((1000.0,))
 
         # now it should connect again
-        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
 
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1008,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
 
         self.reactor.lookups["testserv"] = "1.2.3.4"
 
-        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
 
         # there should be an attempt to connect on port 443 for the .well-known
         clients = self.reactor.tcpClients
@@ -1034,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # another lookup.
         self.reactor.pump((900.0,))
 
-        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
 
         # The resolver may retry a few times, so fonx all requests that come along
         attempts = 0
@@ -1064,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
         self.reactor.pump((10000.0,))
 
         # Repated the request, this time it should fail if the lookup fails.
-        fetch_d = self.well_known_resolver.get_well_known(b"testserv")
+        fetch_d = defer.ensureDeferred(
+            self.well_known_resolver.get_well_known(b"testserv")
+        )
 
         clients = self.reactor.tcpClients
         (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1077,11 +1107,12 @@ class MatrixFederationAgentTests(unittest.TestCase):
     def test_srv_fallbacks(self):
         """Test that other SRV results are tried if the first one fails.
         """
-
-        self.mock_resolver.resolve_service.side_effect = lambda _: [
-            Server(host=b"target.com", port=8443),
-            Server(host=b"target.com", port=8444),
-        ]
+        self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
+            [
+                Server(host=b"target.com", port=8443),
+                Server(host=b"target.com", port=8444),
+            ]
+        )
         self.reactor.lookups["target.com"] = "1.2.3.4"
 
         test_d = self._make_get_request(b"matrix://testserv/foo/bar")
@@ -1233,7 +1264,7 @@ def _log_request(request):
 
 
 @implementer(IPolicyForHTTPS)
-class TrustingTLSPolicyForHTTPS(object):
+class TrustingTLSPolicyForHTTPS:
     """An IPolicyForHTTPS which checks that the certificate belongs to the
     right server, but doesn't check the certificate chain."""
 
diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index babc201643..fee2985d35 100644
--- a/tests/http/federation/test_srv_resolver.py
+++ b/tests/http/federation/test_srv_resolver.py
@@ -22,7 +22,7 @@ from twisted.internet.error import ConnectError
 from twisted.names import dns, error
 
 from synapse.http.federation.srv_resolver import SrvResolver
-from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
+from synapse.logging.context import LoggingContext, current_context
 
 from tests import unittest
 from tests.utils import MockClock
@@ -50,13 +50,7 @@ class SrvResolverTestCase(unittest.TestCase):
 
             with LoggingContext("one") as ctx:
                 resolve_d = resolver.resolve_service(service_name)
-
-                self.assertNoResult(resolve_d)
-
-                # should have reset to the sentinel context
-                self.assertIs(current_context(), SENTINEL_CONTEXT)
-
-                result = yield resolve_d
+                result = yield defer.ensureDeferred(resolve_d)
 
                 # should have restored our context
                 self.assertIs(current_context(), ctx)
@@ -91,7 +85,7 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {service_name: [entry]}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield resolver.resolve_service(service_name)
+        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
         dns_client_mock.lookupService.assert_called_once_with(service_name)
 
@@ -117,7 +111,7 @@ class SrvResolverTestCase(unittest.TestCase):
             dns_client=dns_client_mock, cache=cache, get_time=clock.time
         )
 
-        servers = yield resolver.resolve_service(service_name)
+        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
         self.assertFalse(dns_client_mock.lookupService.called)
 
@@ -136,7 +130,7 @@ class SrvResolverTestCase(unittest.TestCase):
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
         with self.assertRaises(error.DNSServerError):
-            yield resolver.resolve_service(service_name)
+            yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
     @defer.inlineCallbacks
     def test_name_error(self):
@@ -149,7 +143,7 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        servers = yield resolver.resolve_service(service_name)
+        servers = yield defer.ensureDeferred(resolver.resolve_service(service_name))
 
         self.assertEquals(len(servers), 0)
         self.assertEquals(len(cache), 0)
@@ -166,8 +160,8 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        resolve_d = resolver.resolve_service(service_name)
-        self.assertNoResult(resolve_d)
+        # Old versions of Twisted don't have an ensureDeferred in failureResultOf.
+        resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
 
         # returning a single "." should make the lookup fail with a ConenctError
         lookup_deferred.callback(
@@ -192,8 +186,8 @@ class SrvResolverTestCase(unittest.TestCase):
         cache = {}
         resolver = SrvResolver(dns_client=dns_client_mock, cache=cache)
 
-        resolve_d = resolver.resolve_service(service_name)
-        self.assertNoResult(resolve_d)
+        # Old versions of Twisted don't have an ensureDeferred in successResultOf.
+        resolve_d = defer.ensureDeferred(resolver.resolve_service(service_name))
 
         lookup_deferred.callback(
             (
diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py
new file mode 100644
index 0000000000..62d36c2906
--- /dev/null
+++ b/tests/http/test_additional_resource.py
@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.http.additional_resource import AdditionalResource
+from synapse.http.server import respond_with_json
+
+from tests.unittest import HomeserverTestCase
+
+
+class _AsyncTestCustomEndpoint:
+    def __init__(self, config, module_api):
+        pass
+
+    async def handle_request(self, request):
+        respond_with_json(request, 200, {"some_key": "some_value_async"})
+
+
+class _SyncTestCustomEndpoint:
+    def __init__(self, config, module_api):
+        pass
+
+    async def handle_request(self, request):
+        respond_with_json(request, 200, {"some_key": "some_value_sync"})
+
+
+class AdditionalResourceTests(HomeserverTestCase):
+    """Very basic tests that `AdditionalResource` works correctly with sync
+    and async handlers.
+    """
+
+    def test_async(self):
+        handler = _AsyncTestCustomEndpoint({}, None).handle_request
+        self.resource = AdditionalResource(self.hs, handler)
+
+        request, channel = self.make_request("GET", "/")
+        self.render(request)
+
+        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
+
+    def test_sync(self):
+        handler = _SyncTestCustomEndpoint({}, None).handle_request
+        self.resource = AdditionalResource(self.hs, handler)
+
+        request, channel = self.make_request("GET", "/")
+        self.render(request)
+
+        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index fff4f0cbf4..5604af3795 100644
--- a/tests/http/test_fedclient.py
+++ b/tests/http/test_fedclient.py
@@ -16,6 +16,7 @@
 from mock import Mock
 
 from netaddr import IPSet
+from parameterized import parameterized
 
 from twisted.internet import defer
 from twisted.internet.defer import TimeoutError
@@ -58,7 +59,9 @@ class FederationClientTests(HomeserverTestCase):
         @defer.inlineCallbacks
         def do_request():
             with LoggingContext("one") as context:
-                fetch_d = self.cl.get_json("testserv:8008", "foo/bar")
+                fetch_d = defer.ensureDeferred(
+                    self.cl.get_json("testserv:8008", "foo/bar")
+                )
 
                 # Nothing happened yet
                 self.assertNoResult(fetch_d)
@@ -120,7 +123,9 @@ class FederationClientTests(HomeserverTestCase):
         """
         If the DNS lookup returns an error, it will bubble up.
         """
-        d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
+        )
         self.pump()
 
         f = self.failureResultOf(d)
@@ -128,7 +133,9 @@ class FederationClientTests(HomeserverTestCase):
         self.assertIsInstance(f.value.inner_exception, DNSLookupError)
 
     def test_client_connection_refused(self):
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -154,7 +161,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is not connected and is timed out, it'll give a
         ConnectingCancelledError or TimeoutError.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -184,7 +193,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -226,7 +237,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a blacklisted IPv4 address
         # ------------------------------------------------------
         # Make the request
-        d = cl.get_json("internal:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))
 
         # Nothing happened yet
         self.assertNoResult(d)
@@ -244,7 +255,9 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a POST request to a blacklisted IPv6 address
         # -------------------------------------------------------
         # Make the request
-        d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
+        )
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -263,7 +276,7 @@ class FederationClientTests(HomeserverTestCase):
         # Try making a GET request to a non-blacklisted IPv4 address
         # ----------------------------------------------------------
         # Make the request
-        d = cl.post_json("fine:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))
 
         # Nothing has happened yet
         self.assertNoResult(d)
@@ -286,7 +299,7 @@ class FederationClientTests(HomeserverTestCase):
         request = MatrixFederationRequest(
             method="GET", destination="testserv:8008", path="foo/bar"
         )
-        d = self.cl._send_request(request, timeout=10000)
+        d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))
 
         self.pump()
 
@@ -310,7 +323,9 @@ class FederationClientTests(HomeserverTestCase):
         If the HTTP request is connected, but gets no response before being
         timed out, it'll give a ResponseNeverReceived.
         """
-        d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        d = defer.ensureDeferred(
+            self.cl.post_json("testserv:8008", "foo/bar", timeout=10000)
+        )
 
         self.pump()
 
@@ -342,7 +357,9 @@ class FederationClientTests(HomeserverTestCase):
         requiring a trailing slash. We need to retry the request with a
         trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -395,7 +412,9 @@ class FederationClientTests(HomeserverTestCase):
 
         See test_client_requires_trailing_slashes() for context.
         """
-        d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        d = defer.ensureDeferred(
+            self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
+        )
 
         # Send the request
         self.pump()
@@ -432,7 +451,11 @@ class FederationClientTests(HomeserverTestCase):
         self.failureResultOf(d)
 
     def test_client_sends_body(self):
-        self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"})
+        defer.ensureDeferred(
+            self.cl.post_json(
+                "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
+            )
+        )
 
         self.pump()
 
@@ -453,7 +476,7 @@ class FederationClientTests(HomeserverTestCase):
 
     def test_closes_connection(self):
         """Check that the client closes unused HTTP connections"""
-        d = self.cl.get_json("testserv:8008", "foo/bar")
+        d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
 
         self.pump()
 
@@ -486,6 +509,53 @@ class FederationClientTests(HomeserverTestCase):
         self.assertFalse(conn.disconnecting)
 
         # wait for a while
-        self.pump(120)
+        self.reactor.advance(120)
 
         self.assertTrue(conn.disconnecting)
+
+    @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
+    def test_json_error(self, return_value):
+        """
+        Test what happens if invalid JSON is returned from the remote endpoint.
+        """
+
+        test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))
+
+        self.pump()
+
+        # Nothing happened yet
+        self.assertNoResult(test_d)
+
+        # Make sure treq is trying to connect
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8008)
+
+        # complete the connection and wire it up to a fake transport
+        protocol = factory.buildProtocol(None)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        # that should have made it send the request to the transport
+        self.assertRegex(transport.value(), b"^GET /foo/bar")
+        self.assertRegex(transport.value(), b"Host: testserv:8008")
+
+        # Deferred is still without a result
+        self.assertNoResult(test_d)
+
+        # Send it the HTTP response
+        protocol.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Server: Fake\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: %i\r\n"
+            b"\r\n"
+            b"%s" % (len(return_value), return_value)
+        )
+
+        self.pump()
+
+        f = self.failureResultOf(test_d)
+        self.assertIsInstance(f.value, ValueError)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
new file mode 100644
index 0000000000..45089158ce
--- /dev/null
+++ b/tests/http/test_servlet.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from io import BytesIO
+
+from mock import Mock
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+    parse_json_object_from_request,
+    parse_json_value_from_request,
+)
+
+from tests import unittest
+
+
+def make_request(content):
+    """Make an object that acts enough like a request."""
+    request = Mock(spec=["content"])
+
+    if isinstance(content, dict):
+        content = json.dumps(content).encode("utf8")
+
+    request.content = BytesIO(content)
+    return request
+
+
+class TestServletUtils(unittest.TestCase):
+    def test_parse_json_value(self):
+        """Basic tests for parse_json_value_from_request."""
+        # Test round-tripping.
+        obj = {"foo": 1}
+        result = parse_json_value_from_request(make_request(obj))
+        self.assertEqual(result, obj)
+
+        # Results don't have to be objects.
+        result = parse_json_value_from_request(make_request(b'["foo"]'))
+        self.assertEqual(result, ["foo"])
+
+        # Test empty.
+        with self.assertRaises(SynapseError):
+            parse_json_value_from_request(make_request(b""))
+
+        result = parse_json_value_from_request(make_request(b""), allow_empty_body=True)
+        self.assertIsNone(result)
+
+        # Invalid UTF-8.
+        with self.assertRaises(SynapseError):
+            parse_json_value_from_request(make_request(b"\xFF\x00"))
+
+        # Invalid JSON.
+        with self.assertRaises(SynapseError):
+            parse_json_value_from_request(make_request(b"foo"))
+
+        with self.assertRaises(SynapseError):
+            parse_json_value_from_request(make_request(b'{"foo": Infinity}'))
+
+    def test_parse_json_object(self):
+        """Basic tests for parse_json_object_from_request."""
+        # Test empty.
+        result = parse_json_object_from_request(
+            make_request(b""), allow_empty_body=True
+        )
+        self.assertEqual(result, {})
+
+        # Test not an object
+        with self.assertRaises(SynapseError):
+            parse_json_object_from_request(make_request(b'["foo"]'))
diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py
index 451d05c0f0..d36f5f426c 100644
--- a/tests/logging/test_structured.py
+++ b/tests/logging/test_structured.py
@@ -29,12 +29,12 @@ from synapse.logging.context import LoggingContext
 from tests.unittest import DEBUG, HomeserverTestCase
 
 
-class FakeBeginner(object):
+class FakeBeginner:
     def beginLoggingTo(self, observers, **kwargs):
         self.observers = observers
 
 
-class StructuredLoggingTestBase(object):
+class StructuredLoggingTestBase:
     """
     Test base that registers a cleanup handler to reset the stdlib log handler
     to 'unset'.
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 807cd65dd6..04de0b9dbe 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
         # Check that the new user exists with all provided attributes
         self.assertEqual(user_id, "@bob:test")
         self.assertTrue(access_token)
-        self.assertTrue(self.store.get_user_by_id(user_id))
+        self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
 
         # Check that the email was assigned
         emails = self.get_success(self.store.user_get_threepids(user_id))
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 83032cc9ea..3224568640 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -27,7 +27,7 @@ from tests.unittest import HomeserverTestCase
 
 
 @attr.s
-class _User(object):
+class _User:
     "Helper wrapper for user ID and access token"
     id = attr.ib()
     token = attr.ib()
@@ -170,7 +170,7 @@ class EmailPusherTests(HomeserverTestCase):
         last_stream_ordering = pushers[0]["last_stream_ordering"]
 
         # Advance time a bit, so the pusher will register something has happened
-        self.pump(100)
+        self.pump(10)
 
         # It hasn't succeeded yet, so the stream ordering shouldn't have moved
         pushers = self.get_success(
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index baf9c785f4..b567868b02 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -25,7 +25,6 @@ from tests.unittest import HomeserverTestCase
 
 
 class HTTPPusherTests(HomeserverTestCase):
-
     servlets = [
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -35,7 +34,6 @@ class HTTPPusherTests(HomeserverTestCase):
     hijack_auth = False
 
     def make_homeserver(self, reactor, clock):
-
         self.push_attempts = []
 
         m = Mock()
@@ -90,9 +88,6 @@ class HTTPPusherTests(HomeserverTestCase):
         # Create a room
         room = self.helper.create_room_as(user_id, tok=access_token)
 
-        # Invite the other person
-        self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
-
         # The other user joins
         self.helper.join(room=room, user=other_user_id, tok=other_access_token)
 
@@ -157,3 +152,350 @@ class HTTPPusherTests(HomeserverTestCase):
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+
+    def test_sends_high_priority_for_encrypted(self):
+        """
+        The HTTP pusher will send pushes at high priority if they correspond
+        to an encrypted message.
+        This will happen both in 1:1 rooms and larger rooms.
+        """
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the user who sends the message
+        other_user_id = self.register_user("otheruser", "pass")
+        other_access_token = self.login("otheruser", "pass")
+
+        # Register a third user
+        yet_another_user_id = self.register_user("yetanotheruser", "pass")
+        yet_another_access_token = self.login("yetanotheruser", "pass")
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "example.com"},
+            )
+        )
+
+        # Send an encrypted event
+        # I know there'd normally be set-up of an encrypted room first
+        # but this will do for our purposes
+        self.helper.send_event(
+            room,
+            "m.room.encrypted",
+            content={
+                "algorithm": "m.megolm.v1.aes-sha2",
+                "sender_key": "6lImKbzK51MzWLwHh8tUM3UBBSBrLlgup/OOCGTvumM",
+                "ciphertext": "AwgAErABoRxwpMipdgiwXgu46rHiWQ0DmRj0qUlPrMraBUDk"
+                "leTnJRljpuc7IOhsYbLY3uo2WI0ab/ob41sV+3JEIhODJPqH"
+                "TK7cEZaIL+/up9e+dT9VGF5kRTWinzjkeqO8FU5kfdRjm+3w"
+                "0sy3o1OCpXXCfO+faPhbV/0HuK4ndx1G+myNfK1Nk/CxfMcT"
+                "BT+zDS/Df/QePAHVbrr9uuGB7fW8ogW/ulnydgZPRluusFGv"
+                "J3+cg9LoPpZPAmv5Me3ec7NtdlfN0oDZ0gk3TiNkkhsxDG9Y"
+                "YcNzl78USI0q8+kOV26Bu5dOBpU4WOuojXZHJlP5lMgdzLLl"
+                "EQ0",
+                "session_id": "IigqfNWLL+ez/Is+Duwp2s4HuCZhFG9b9CZKTYHtQ4A",
+                "device_id": "AHQDUSTAAA",
+            },
+            tok=other_access_token,
+        )
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        # Make the push succeed
+        self.push_attempts[0][0].callback({})
+        self.pump()
+
+        # Check our push made it with high priority
+        self.assertEqual(len(self.push_attempts), 1)
+        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+        # Add yet another person — we want to make this room not a 1:1
+        # (as encrypted messages in a 1:1 currently have tweaks applied
+        #  so it doesn't properly exercise the condition of all encrypted
+        #  messages need to be high).
+        self.helper.join(
+            room=room, user=yet_another_user_id, tok=yet_another_access_token
+        )
+
+        # Check no push notifications are sent regarding the membership changes
+        # (that would confuse the test)
+        self.pump()
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Send another encrypted event
+        self.helper.send_event(
+            room,
+            "m.room.encrypted",
+            content={
+                "ciphertext": "AwgAEoABtEuic/2DF6oIpNH+q/PonzlhXOVho8dTv0tzFr5m"
+                "9vTo50yabx3nxsRlP2WxSqa8I07YftP+EKWCWJvTkg6o7zXq"
+                "6CK+GVvLQOVgK50SfvjHqJXN+z1VEqj+5mkZVN/cAgJzoxcH"
+                "zFHkwDPJC8kQs47IHd8EO9KBUK4v6+NQ1uE/BIak4qAf9aS/"
+                "kI+f0gjn9IY9K6LXlah82A/iRyrIrxkCkE/n0VfvLhaWFecC"
+                "sAWTcMLoF6fh1Jpke95mljbmFSpsSd/eEQw",
+                "device_id": "SRCFTWTHXO",
+                "session_id": "eMA+bhGczuTz1C5cJR1YbmrnnC6Goni4lbvS5vJ1nG4",
+                "algorithm": "m.megolm.v1.aes-sha2",
+                "sender_key": "rC/XSIAiYrVGSuaHMop8/pTZbku4sQKBZwRwukgnN1c",
+            },
+            tok=other_access_token,
+        )
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+        self.assertEqual(len(self.push_attempts), 2)
+        self.assertEqual(self.push_attempts[1][1], "example.com")
+        self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "high")
+
+    def test_sends_high_priority_for_one_to_one_only(self):
+        """
+        The HTTP pusher will send pushes at high priority if they correspond
+        to a message in a one-to-one room.
+        """
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the user who sends the message
+        other_user_id = self.register_user("otheruser", "pass")
+        other_access_token = self.login("otheruser", "pass")
+
+        # Register a third user
+        yet_another_user_id = self.register_user("yetanotheruser", "pass")
+        yet_another_access_token = self.login("yetanotheruser", "pass")
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "example.com"},
+            )
+        )
+
+        # Send a message
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        # Make the push succeed
+        self.push_attempts[0][0].callback({})
+        self.pump()
+
+        # Check our push made it with high priority — this is a one-to-one room
+        self.assertEqual(len(self.push_attempts), 1)
+        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+        # Yet another user joins
+        self.helper.join(
+            room=room, user=yet_another_user_id, tok=yet_another_access_token
+        )
+
+        # Check no push notifications are sent regarding the membership changes
+        # (that would confuse the test)
+        self.pump()
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Send another event
+        self.helper.send(room, body="Welcome!", tok=other_access_token)
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+        self.assertEqual(len(self.push_attempts), 2)
+        self.assertEqual(self.push_attempts[1][1], "example.com")
+
+        # check that this is low-priority
+        self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+
+    def test_sends_high_priority_for_mention(self):
+        """
+        The HTTP pusher will send pushes at high priority if they correspond
+        to a message containing the user's display name.
+        """
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the user who sends the message
+        other_user_id = self.register_user("otheruser", "pass")
+        other_access_token = self.login("otheruser", "pass")
+
+        # Register a third user
+        yet_another_user_id = self.register_user("yetanotheruser", "pass")
+        yet_another_access_token = self.login("yetanotheruser", "pass")
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other users join
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+        self.helper.join(
+            room=room, user=yet_another_user_id, tok=yet_another_access_token
+        )
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "example.com"},
+            )
+        )
+
+        # Send a message
+        self.helper.send(room, body="Oh, user, hello!", tok=other_access_token)
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        # Make the push succeed
+        self.push_attempts[0][0].callback({})
+        self.pump()
+
+        # Check our push made it with high priority
+        self.assertEqual(len(self.push_attempts), 1)
+        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+        # Send another event, this time with no mention
+        self.helper.send(room, body="Are you there?", tok=other_access_token)
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+        self.assertEqual(len(self.push_attempts), 2)
+        self.assertEqual(self.push_attempts[1][1], "example.com")
+
+        # check that this is low-priority
+        self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+
+    def test_sends_high_priority_for_atroom(self):
+        """
+        The HTTP pusher will send pushes at high priority if they correspond
+        to a message that contains @room.
+        """
+        # Register the user who gets notified
+        user_id = self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # Register the user who sends the message
+        other_user_id = self.register_user("otheruser", "pass")
+        other_access_token = self.login("otheruser", "pass")
+
+        # Register a third user
+        yet_another_user_id = self.register_user("yetanotheruser", "pass")
+        yet_another_access_token = self.login("yetanotheruser", "pass")
+
+        # Create a room (as other_user so the power levels are compatible with
+        # other_user sending @room).
+        room = self.helper.create_room_as(other_user_id, tok=other_access_token)
+
+        # The other users join
+        self.helper.join(room=room, user=user_id, tok=access_token)
+        self.helper.join(
+            room=room, user=yet_another_user_id, tok=yet_another_access_token
+        )
+
+        # Register the pusher
+        user_tuple = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "example.com"},
+            )
+        )
+
+        # Send a message
+        self.helper.send(
+            room,
+            body="@room eeek! There's a spider on the table!",
+            tok=other_access_token,
+        )
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        # Make the push succeed
+        self.push_attempts[0][0].callback({})
+        self.pump()
+
+        # Check our push made it with high priority
+        self.assertEqual(len(self.push_attempts), 1)
+        self.assertEqual(self.push_attempts[0][1], "example.com")
+        self.assertEqual(self.push_attempts[0][2]["notification"]["prio"], "high")
+
+        # Send another event, this time as someone without the power of @room
+        self.helper.send(
+            room, body="@room the spider is gone", tok=yet_another_access_token
+        )
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+        self.assertEqual(len(self.push_attempts), 2)
+        self.assertEqual(self.push_attempts[1][1], "example.com")
+
+        # check that this is low-priority
+        self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index 9ae6a87d7b..1f4b5ca2ac 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -15,13 +15,14 @@
 
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
+from synapse.push import push_rule_evaluator
 from synapse.push.push_rule_evaluator import PushRuleEvaluatorForEvent
 
 from tests import unittest
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def setUp(self):
+    def _get_evaluator(self, content):
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -29,37 +30,74 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
                 "sender": "@user:test",
                 "state_key": "",
                 "room_id": "@room:test",
-                "content": {"body": "foo bar baz"},
+                "content": content,
             },
             RoomVersions.V1,
         )
         room_member_count = 0
         sender_power_level = 0
         power_levels = {}
-        self.evaluator = PushRuleEvaluatorForEvent(
+        return PushRuleEvaluatorForEvent(
             event, room_member_count, sender_power_level, power_levels
         )
 
     def test_display_name(self):
         """Check for a matching display name in the body of the event."""
+        evaluator = self._get_evaluator({"body": "foo bar baz"})
+
         condition = {
             "kind": "contains_display_name",
         }
 
         # Blank names are skipped.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", ""))
+        self.assertFalse(evaluator.matches(condition, "@user:test", ""))
 
         # Check a display name that doesn't match.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", "not found"))
+        self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
 
         # Check a display name which matches.
-        self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo"))
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
 
         # A display name that matches, but not a full word does not result in a match.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba"))
+        self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
 
         # A display name should not be interpreted as a regular expression.
-        self.assertFalse(self.evaluator.matches(condition, "@user:test", "ba[rz]"))
+        self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
 
         # A display name with spaces should work fine.
-        self.assertTrue(self.evaluator.matches(condition, "@user:test", "foo bar"))
+        self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
+
+    def test_no_body(self):
+        """Not having a body shouldn't break the evaluator."""
+        evaluator = self._get_evaluator({})
+
+        condition = {
+            "kind": "contains_display_name",
+        }
+        self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+    def test_invalid_body(self):
+        """A non-string body should not break the evaluator."""
+        condition = {
+            "kind": "contains_display_name",
+        }
+
+        for body in (1, True, {"foo": "bar"}):
+            evaluator = self._get_evaluator({"body": body})
+            self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
+
+    def test_tweaks_for_actions(self):
+        """
+        This tests the behaviour of tweaks_for_actions.
+        """
+
+        actions = [
+            {"set_tweak": "sound", "value": "default"},
+            {"set_tweak": "highlight"},
+            "notify",
+        ]
+
+        self.assertEqual(
+            push_rule_evaluator.tweaks_for_actions(actions),
+            {"sound": "default", "highlight": True},
+        )
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..ae60874ec3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 
 import attr
 
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
     GenericWorkerServer,
 )
+from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
 
 logger = logging.getLogger(__name__)
 
@@ -64,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         # Since we use sqlite in memory databases we need to make sure the
         # databases objects are the same.
-        self.worker_hs.get_datastore().db = hs.get_datastore().db
+        self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
 
         self.test_handler = self._build_replication_data_handler()
         self.worker_hs.replication_data_handler = self.test_handler
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+    """Base class for tests running multiple workers.
+
+    Automatically handle HTTP replication requests from workers to master,
+    unlike `BaseStreamTestCase`.
+    """
+
+    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]
+
+    def setUp(self):
+        super().setUp()
+
+        # build a replication server
+        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+        self.streamer = self.hs.get_replication_streamer()
+
+        store = self.hs.get_datastore()
+        self.database_pool = store.db_pool
+
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        self._worker_hs_to_resource = {}
+
+        # When we see a connection attempt to the master replication listener we
+        # automatically set up the connection. This is so that tests don't
+        # manually have to go and explicitly set it up each time (plus sometimes
+        # it is impossible to write the handling explicitly in the tests).
+        self.reactor.add_tcp_client_callback(
+            "1.2.3.4", 8765, self._handle_http_replication_attempt
+        )
+
+    def create_test_json_resource(self):
+        """Overrides `HomeserverTestCase.create_test_json_resource`.
+        """
+        # We override this so that it automatically registers all the HTTP
+        # replication servlets, without having to explicitly do that in all
+        # subclassses.
+
+        resource = ReplicationRestResource(self.hs)
+
+        for servlet in self.servlets:
+            servlet(self.hs, resource)
+
+        return resource
+
+    def make_worker_hs(
+        self, worker_app: str, extra_config: dict = {}, **kwargs
+    ) -> HomeServer:
+        """Make a new worker HS instance, correctly connecting replcation
+        stream to the master HS.
+
+        Args:
+            worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+            extra_config: Any extra config to use for this instances.
+            **kwargs: Options that get passed to `self.setup_test_homeserver`,
+                useful to e.g. pass some mocks for things like `http_client`
+
+        Returns:
+            The new worker HomeServer instance.
+        """
+
+        config = self._get_worker_hs_config()
+        config["worker_app"] = worker_app
+        config.update(extra_config)
+
+        worker_hs = self.setup_test_homeserver(
+            homeserverToUse=GenericWorkerServer,
+            config=config,
+            reactor=self.reactor,
+            **kwargs
+        )
+
+        store = worker_hs.get_datastore()
+        store.db_pool._db_pool = self.database_pool._db_pool
+
+        repl_handler = ReplicationCommandHandler(worker_hs)
+        client = ClientReplicationStreamProtocol(
+            worker_hs, "client", "test", self.clock, repl_handler,
+        )
+        server = self.server_factory.buildProtocol(None)
+
+        client_transport = FakeTransport(server, self.reactor)
+        client.makeConnection(client_transport)
+
+        server_transport = FakeTransport(client, self.reactor)
+        server.makeConnection(server_transport)
+
+        # Set up a resource for the worker
+        resource = ReplicationRestResource(self.hs)
+
+        for servlet in self.servlets:
+            servlet(worker_hs, resource)
+
+        self._worker_hs_to_resource[worker_hs] = resource
+
+        return worker_hs
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
+    def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+    def replicate(self):
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        self.streamer.on_notifier_poke()
+        self.pump()
+
+    def _handle_http_replication_attempt(self):
+        """Handles a connection attempt to the master replication HTTP
+        listener.
+        """
+
+        # We should have at least one outbound connection attempt, where the
+        # last is one to the HTTP repication IP/port.
+        clients = self.reactor.tcpClients
+        self.assertGreaterEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8765)
+
+        # Set up client side protocol
+        client_protocol = client_factory.buildProtocol(None)
+
+        request_factory = OneShotRequestFactory()
+
+        # Set up the server side protocol
+        channel = _PushHTTPChannel(self.reactor)
+        channel.requestFactory = request_factory
+        channel.site = self.site
+
+        # Connect client to server and vice versa.
+        client_to_server_transport = FakeTransport(
+            channel, self.reactor, client_protocol
+        )
+        client_protocol.makeConnection(client_to_server_transport)
+
+        server_to_client_transport = FakeTransport(
+            client_protocol, self.reactor, channel
+        )
+        channel.makeConnection(server_to_client_transport)
+
+        # Note: at this point we've wired everything up, but we need to return
+        # before the data starts flowing over the connections as this is called
+        # inside `connecTCP` before the connection has been passed back to the
+        # code that requested the TCP connection.
+
+
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
             # We need to manually stop the _PullToPushProducer.
             self._pull_to_push_producer.stop()
 
+    def checkPersistence(self, request, version):
+        """Check whether the connection can be re-used
+        """
+        # We hijack this to always say no for ease of wiring stuff up in
+        # `handle_http_replication_attempt`.
+        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+        return False
+
 
 class _PullToPushProducer:
     """A push producer that wraps a pull producer.
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "notify_count": 0},
+            {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
         )
 
         self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 0, "notify_count": 1},
+            {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
         )
 
         self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_unread_event_push_actions_by_room_for_user",
             [ROOM_ID, USER_ID_2, event1.event_id],
-            {"highlight_count": 1, "notify_count": 2},
+            {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
         )
 
     def test_get_rooms_for_user_with_stream_ordering(self):
@@ -366,7 +366,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         state_handler = self.hs.get_state_handler()
         context = self.get_success(state_handler.compute_event_context(event))
 
-        self.master_store.add_push_actions_to_staging(
-            event.event_id, {user_id: actions for user_id, actions in push_actions}
+        self.get_success(
+            self.master_store.add_push_actions_to_staging(
+                event.event_id,
+                {user_id: actions for user_id, actions in push_actions},
+                False,
+            )
         )
         return event, context
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..c9998e88e6 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
 from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
 from synapse.replication.tcp.streams.events import (
     EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # also one state event
         state_event = self._inject_state_event()
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -123,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         OTHER_USER = "@other_user:localhost"
 
         # have the user join
-        inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+        self.get_success(
+            inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+        )
 
         # Update existing power levels with mod at PL50
         pls = self.helper.get_state(
@@ -161,24 +159,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # roll back all the state by de-modding the user
         prev_events = fork_point
         pls["users"][OTHER_USER] = 0
-        pl_event = inject_event(
-            self.hs,
-            prev_event_ids=prev_events,
-            type=EventTypes.PowerLevels,
-            state_key="",
-            sender=self.user_id,
-            room_id=self.room_id,
-            content=pls,
+        pl_event = self.get_success(
+            inject_event(
+                self.hs,
+                prev_event_ids=prev_events,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                sender=self.user_id,
+                room_id=self.room_id,
+                content=pls,
+            )
         )
 
         # one more bit of state that doesn't get rolled back
         state2 = self._inject_state_event()
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -277,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         # have the users join
         for u in user_ids:
-            inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+            self.get_success(
+                inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+            )
 
         # Update existing power levels with mod at PL50
         pls = self.helper.get_state(
@@ -315,23 +312,20 @@ class EventsStreamTestCase(BaseStreamTestCase):
         pl_events = []
         for u in user_ids:
             pls["users"][u] = 0
-            e = inject_event(
-                self.hs,
-                prev_event_ids=prev_events,
-                type=EventTypes.PowerLevels,
-                state_key="",
-                sender=self.user_id,
-                room_id=self.room_id,
-                content=pls,
+            e = self.get_success(
+                inject_event(
+                    self.hs,
+                    prev_event_ids=prev_events,
+                    type=EventTypes.PowerLevels,
+                    state_key="",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content=pls,
+                )
             )
             prev_events = [e.event_id]
             pl_events.append(e)
 
-        # tell the notifier to catch up to avoid duplicate rows.
-        # workaround for https://github.com/matrix-org/synapse/issues/7360
-        # FIXME remove this when the above is fixed
-        self.replicate()
-
         # check we're testing what we think we are: no rows should yet have been
         # received
         self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +372,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
+    def test_backwards_stream_id(self):
+        """
+        Test that RDATA that comes after the current position should be discarded.
+        """
+        # disconnect, so that we can stack up some changes
+        self.disconnect()
+
+        # Generate an events. We inject them using inject_event so that they are
+        # not send out over replication until we call self.replicate().
+        event = self._inject_test_event()
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual([], self.test_handler.received_rdata_rows)
+
+        # now reconnect to pull the updates
+        self.reconnect()
+        self.replicate()
+
+        # We should have received the expected single row (as well as various
+        # cache invalidation updates which we ignore).
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+
+        # There should be a single received row.
+        self.assertEqual(len(received_rows), 1)
+
+        stream_name, token, row = received_rows[0]
+        self.assertEqual("events", stream_name)
+        self.assertIsInstance(row, EventsStreamRow)
+        self.assertEqual(row.type, "ev")
+        self.assertIsInstance(row.data, EventsStreamEventRow)
+        self.assertEqual(row.data.event_id, event.event_id)
+
+        # Reset the data.
+        self.test_handler.received_rdata_rows = []
+
+        # Save the current token for later.
+        worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+        prev_token = worker_events_stream.current_token("master")
+
+        # Manually send an old RDATA command, which should get dropped. This
+        # re-uses the row from above, but with an earlier stream token.
+        self.hs.get_tcp_replication().send_command(
+            RdataCommand("events", "master", 1, row)
+        )
+
+        # No updates have been received (because it was discard as old).
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+        self.assertEqual(len(received_rows), 0)
+
+        # Ensure the stream has not gone backwards.
+        current_token = worker_events_stream.current_token("master")
+        self.assertGreaterEqual(current_token, prev_token)
+
     event_count = 0
 
     def _inject_test_event(
@@ -390,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase):
             body = "event %i" % (self.event_count,)
             self.event_count += 1
 
-        return inject_event(
-            self.hs,
-            room_id=self.room_id,
-            sender=sender,
-            type="test_event",
-            content={"body": body},
-            **kwargs
+        return self.get_success(
+            inject_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=sender,
+                type="test_event",
+                content={"body": body},
+                **kwargs
+            )
         )
 
     def _inject_state_event(
@@ -415,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase):
         if body is None:
             body = "state event %s" % (state_key,)
 
-        return inject_event(
-            self.hs,
-            room_id=self.room_id,
-            sender=sender,
-            type="test_state_event",
-            state_key=state_key,
-            content={"body": body},
+        return self.get_success(
+            inject_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=sender,
+                type="test_state_event",
+                state_key=state_key,
+                content={"body": body},
+            )
         )
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
 
 from synapse.handlers.typing import RoomMember
 from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from tests.replication._base import BaseStreamTestCase
 
 USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
     def test_typing(self):
         typing = self.hs.get_typing_handler()
 
-        room_id = "!bar:blue"
-
         self.reconnect()
 
-        typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
 
         self.reactor.advance(0)
 
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
-        self.assertEqual(room_id, row.room_id)
+        self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([USER_ID], row.user_ids)
 
         # Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         self.test_handler.on_rdata.reset_mock()
 
-        typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
         self.test_handler.on_rdata.assert_not_called()
 
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
-        self.assertEqual(room_id, row.room_id)
+        self.assertEqual(ROOM_ID, row.room_id)
+        self.assertEqual([], row.user_ids)
+
+    def test_reset(self):
+        """
+        Test what happens when a typing stream resets.
+
+        This is emulated by jumping the stream ahead, then reconnecting (which
+        sends the proper position and RDATA).
+        """
+        typing = self.hs.get_typing_handler()
+
+        self.reconnect()
+
+        typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+        self.reactor.advance(0)
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "typing")
+        self.assertEqual(1, len(rdata_rows))
+        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        self.assertEqual(ROOM_ID, row.room_id)
+        self.assertEqual([USER_ID], row.user_ids)
+
+        # Push the stream forward a bunch so it can be reset.
+        for i in range(100):
+            typing._push_update(
+                member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+            )
+        self.reactor.advance(0)
+
+        # Disconnect.
+        self.disconnect()
+
+        # Reset the typing handler
+        self.hs.get_replication_streams()["typing"].last_token = 0
+        self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+        typing._latest_room_serial = 0
+        typing._typing_stream_change_cache = StreamChangeCache(
+            "TypingStreamChangeCache", typing._latest_room_serial
+        )
+        typing._reset()
+
+        # Reconnect.
+        self.reconnect()
+        self.pump(0.1)
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+        # Reset the test code.
+        self.test_handler.on_rdata.reset_mock()
+        self.test_handler.on_rdata.assert_not_called()
+
+        # Push additional data.
+        typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+        self.reactor.advance(0)
+
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "typing")
+        self.assertEqual(1, len(rdata_rows))
+        row = rdata_rows[0]
+        self.assertEqual(ROOM_ID_2, row.room_id)
         self.assertEqual([], row.user_ids)
+
+        # The token should have been reset.
+        self.assertEqual(token, 1)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
new file mode 100644
index 0000000000..86c03fd89c
--- /dev/null
+++ b/tests/replication/test_client_reader_shard.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import LoginType
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
+from tests.server import FakeChannel
+
+logger = logging.getLogger(__name__)
+
+
+class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
+    """Base class for tests of the replication streams"""
+
+    servlets = [register.register_servlets]
+
+    def prepare(self, reactor, clock, hs):
+        self.recaptcha_checker = DummyRecaptchaChecker(hs)
+        auth_handler = hs.get_auth_handler()
+        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.client_reader"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
+    def test_register_single_worker(self):
+        """Test that registration works when using a single client reader worker.
+        """
+        worker_hs = self.make_worker_hs("synapse.app.client_reader")
+
+        request_1, channel_1 = self.make_request(
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )  # type: SynapseRequest, FakeChannel
+        self.render_on_worker(worker_hs, request_1)
+        self.assertEqual(request_1.code, 401)
+
+        # Grab the session
+        session = channel_1.json_body["session"]
+
+        # also complete the dummy auth
+        request_2, channel_2 = self.make_request(
+            "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+        )  # type: SynapseRequest, FakeChannel
+        self.render_on_worker(worker_hs, request_2)
+        self.assertEqual(request_2.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel_2.json_body["user_id"], "@user:test")
+
+    def test_register_multi_worker(self):
+        """Test that registration works when using multiple client reader workers.
+        """
+        worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
+        worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
+
+        request_1, channel_1 = self.make_request(
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )  # type: SynapseRequest, FakeChannel
+        self.render_on_worker(worker_hs_1, request_1)
+        self.assertEqual(request_1.code, 401)
+
+        # Grab the session
+        session = channel_1.json_body["session"]
+
+        # also complete the dummy auth
+        request_2, channel_2 = self.make_request(
+            "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+        )  # type: SynapseRequest, FakeChannel
+        self.render_on_worker(worker_hs_2, request_2)
+        self.assertEqual(request_2.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 5448d9f0dc..23be1167a3 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+
         return hs
 
     def test_federation_ack_sent(self):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
new file mode 100644
index 0000000000..8b4982ecb1
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events.builder import EventBuilderFactory
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import UserID, create_requester
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
+    servlets = [
+        login.register_servlets,
+        register_servlets_for_client_rest_resource,
+        room.register_servlets,
+    ]
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["send_federation"] = False
+        return conf
+
+    def test_send_event_single_sender(self):
+        """Test that using a single federation sender worker correctly sends a
+        new event.
+        """
+        mock_client = Mock(spec=["put_json"])
+        mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
+            {"send_federation": True},
+            http_client=mock_client,
+        )
+
+        user = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room = self.create_room_with_remote_server(user, token)
+
+        mock_client.put_json.reset_mock()
+
+        self.create_and_send_event(room, UserID.from_string(user))
+        self.replicate()
+
+        # Assert that the event was sent out over federation.
+        mock_client.put_json.assert_called()
+        self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
+        self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
+
+    def test_send_event_sharded(self):
+        """Test that using two federation sender workers correctly sends
+        new events.
+        """
+        mock_client1 = Mock(spec=["put_json"])
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
+            {
+                "send_federation": True,
+                "worker_name": "sender1",
+                "federation_sender_instances": ["sender1", "sender2"],
+            },
+            http_client=mock_client1,
+        )
+
+        mock_client2 = Mock(spec=["put_json"])
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
+            {
+                "send_federation": True,
+                "worker_name": "sender2",
+                "federation_sender_instances": ["sender1", "sender2"],
+            },
+            http_client=mock_client2,
+        )
+
+        user = self.register_user("user2", "pass")
+        token = self.login("user2", "pass")
+
+        sent_on_1 = False
+        sent_on_2 = False
+        for i in range(20):
+            server_name = "other_server_%d" % (i,)
+            room = self.create_room_with_remote_server(user, token, server_name)
+            mock_client1.reset_mock()  # type: ignore[attr-defined]
+            mock_client2.reset_mock()  # type: ignore[attr-defined]
+
+            self.create_and_send_event(room, UserID.from_string(user))
+            self.replicate()
+
+            if mock_client1.put_json.called:
+                sent_on_1 = True
+                mock_client2.put_json.assert_not_called()
+                self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus"))
+            elif mock_client2.put_json.called:
+                sent_on_2 = True
+                mock_client1.put_json.assert_not_called()
+                self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus"))
+            else:
+                raise AssertionError(
+                    "Expected send transaction from one or the other sender"
+                )
+
+            if sent_on_1 and sent_on_2:
+                break
+
+        self.assertTrue(sent_on_1)
+        self.assertTrue(sent_on_2)
+
+    def test_send_typing_sharded(self):
+        """Test that using two federation sender workers correctly sends
+        new typing EDUs.
+        """
+        mock_client1 = Mock(spec=["put_json"])
+        mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
+            {
+                "send_federation": True,
+                "worker_name": "sender1",
+                "federation_sender_instances": ["sender1", "sender2"],
+            },
+            http_client=mock_client1,
+        )
+
+        mock_client2 = Mock(spec=["put_json"])
+        mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+        self.make_worker_hs(
+            "synapse.app.federation_sender",
+            {
+                "send_federation": True,
+                "worker_name": "sender2",
+                "federation_sender_instances": ["sender1", "sender2"],
+            },
+            http_client=mock_client2,
+        )
+
+        user = self.register_user("user3", "pass")
+        token = self.login("user3", "pass")
+
+        typing_handler = self.hs.get_typing_handler()
+
+        sent_on_1 = False
+        sent_on_2 = False
+        for i in range(20):
+            server_name = "other_server_%d" % (i,)
+            room = self.create_room_with_remote_server(user, token, server_name)
+            mock_client1.reset_mock()  # type: ignore[attr-defined]
+            mock_client2.reset_mock()  # type: ignore[attr-defined]
+
+            self.get_success(
+                typing_handler.started_typing(
+                    target_user=UserID.from_string(user),
+                    requester=create_requester(user),
+                    room_id=room,
+                    timeout=20000,
+                )
+            )
+
+            self.replicate()
+
+            if mock_client1.put_json.called:
+                sent_on_1 = True
+                mock_client2.put_json.assert_not_called()
+                self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus"))
+            elif mock_client2.put_json.called:
+                sent_on_2 = True
+                mock_client1.put_json.assert_not_called()
+                self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus"))
+            else:
+                raise AssertionError(
+                    "Expected send transaction from one or the other sender"
+                )
+
+            if sent_on_1 and sent_on_2:
+                break
+
+        self.assertTrue(sent_on_1)
+        self.assertTrue(sent_on_2)
+
+    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+        room = self.helper.create_room_as(user, tok=token)
+        store = self.hs.get_datastore()
+        federation = self.hs.get_handlers().federation_handler
+
+        prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+        room_version = self.get_success(store.get_room_version(room))
+
+        factory = EventBuilderFactory(self.hs)
+        factory.hostname = remote_server
+
+        user_id = UserID("user", remote_server).to_string()
+
+        event_dict = {
+            "type": EventTypes.Member,
+            "state_key": user_id,
+            "content": {"membership": Membership.JOIN},
+            "sender": user_id,
+            "room_id": room,
+        }
+
+        builder = factory.for_room_version(room_version, event_dict)
+        join_event = self.get_success(builder.build(prev_event_ids))
+
+        self.get_success(federation.on_send_join_request(remote_server, join_event))
+        self.replicate()
+
+        return room
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
new file mode 100644
index 0000000000..2bdc6edbb1
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
@@ -0,0 +1,193 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+    """Checks pusher sharding works
+    """
+
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # Register a user who sends a message that we'll get notified about
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["start_pushers"] = False
+        return conf
+
+    def _create_pusher_and_send_msg(self, localpart):
+        # Create a user that will get push notifications
+        user_id = self.register_user(localpart, "pass")
+        access_token = self.login(localpart, "pass")
+
+        # Register a pusher
+        user_dict = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(access_token)
+        )
+        token_id = user_dict["token_id"]
+
+        self.get_success(
+            self.hs.get_pusherpool().add_pusher(
+                user_id=user_id,
+                access_token=token_id,
+                kind="http",
+                app_id="m.http",
+                app_display_name="HTTP Push Notifications",
+                device_display_name="pushy push",
+                pushkey="a@example.com",
+                lang=None,
+                data={"url": "https://push.example.com/push"},
+            )
+        )
+
+        self.pump()
+
+        # Create a room
+        room = self.helper.create_room_as(user_id, tok=access_token)
+
+        # The other user joins
+        self.helper.join(
+            room=room, user=self.other_user_id, tok=self.other_access_token
+        )
+
+        # The other user sends some messages
+        response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+        event_id = response["event_id"]
+
+        return event_id
+
+    def test_send_push_single_worker(self):
+        """Test that registration works when using a pusher worker.
+        """
+        http_client_mock = Mock(spec_set=["post_json_get_json"])
+        http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {"start_pushers": True},
+            proxied_http_client=http_client_mock,
+        )
+
+        event_id = self._create_pusher_and_send_msg("user")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock.post_json_get_json.assert_called_once()
+        self.assertEqual(
+            http_client_mock.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )
+
+    def test_send_push_multiple_workers(self):
+        """Test that registration works when using sharded pusher workers.
+        """
+        http_client_mock1 = Mock(spec_set=["post_json_get_json"])
+        http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {
+                "start_pushers": True,
+                "worker_name": "pusher1",
+                "pusher_instances": ["pusher1", "pusher2"],
+            },
+            proxied_http_client=http_client_mock1,
+        )
+
+        http_client_mock2 = Mock(spec_set=["post_json_get_json"])
+        http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+            {}
+        )
+
+        self.make_worker_hs(
+            "synapse.app.pusher",
+            {
+                "start_pushers": True,
+                "worker_name": "pusher2",
+                "pusher_instances": ["pusher1", "pusher2"],
+            },
+            proxied_http_client=http_client_mock2,
+        )
+
+        # We choose a user name that we know should go to pusher1.
+        event_id = self._create_pusher_and_send_msg("user2")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock1.post_json_get_json.assert_called_once()
+        http_client_mock2.post_json_get_json.assert_not_called()
+        self.assertEqual(
+            http_client_mock1.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )
+
+        http_client_mock1.post_json_get_json.reset_mock()
+        http_client_mock2.post_json_get_json.reset_mock()
+
+        # Now we choose a user name that we know should go to pusher2.
+        event_id = self._create_pusher_and_send_msg("user4")
+
+        # Advance time a bit, so the pusher will register something has happened
+        self.pump()
+
+        http_client_mock1.post_json_get_json.assert_not_called()
+        http_client_mock2.post_json_get_json.assert_called_once()
+        self.assertEqual(
+            http_client_mock2.post_json_get_json.call_args[0][0],
+            "https://push.example.com/push",
+        )
+        self.assertEqual(
+            event_id,
+            http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
+                "event_id"
+            ],
+        )
diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 977615ebef..0f1144fe1e 100644
--- a/tests/rest/admin/test_admin.py
+++ b/tests/rest/admin/test_admin.py
@@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         self.fetches = []
 
-        def get_file(destination, path, output_stream, args=None, max_size=None):
+        async def get_file(destination, path, output_stream, args=None, max_size=None):
             """
             Returns tuple[int,dict,str,int] of file length, response headers,
             absolute URI, and response code.
@@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
             d = Deferred()
             d.addCallback(write_to)
             self.fetches.append((d, destination, path, args))
-            return make_deferred_yieldable(d)
+            return await make_deferred_yieldable(d)
 
         client = Mock()
         client.get_file = get_file
@@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
 
         return hs
 
+    def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
+        """Ensure a piece of media is quarantined when trying to access it."""
+        request, channel = self.make_request(
+            "GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
+        )
+        request.render(self.download_resource)
+        self.pump(1.0)
+
+        # Should be quarantined
+        self.assertEqual(
+            404,
+            int(channel.code),
+            msg=(
+                "Expected to receive a 404 on accessing quarantined media: %s"
+                % server_and_media_id
+            ),
+        )
+
     def test_quarantine_media_requires_admin(self):
         self.register_user("nonadmin", "pass", admin=False)
         non_admin_user_tok = self.login("nonadmin", "pass")
@@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         self.assertEqual(200, int(channel.code), msg=channel.result["body"])
 
         # Attempt to access the media
-        request, channel = self.make_request(
-            "GET",
-            server_name_and_media_id,
-            shorthand=False,
-            access_token=admin_user_tok,
-        )
-        request.render(self.download_resource)
-        self.pump(1.0)
-
-        # Should be quarantined
-        self.assertEqual(
-            404,
-            int(channel.code),
-            msg=(
-                "Expected to receive a 404 on accessing quarantined media: %s"
-                % server_name_and_media_id
-            ),
-        )
+        self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
 
     def test_quarantine_all_media_in_room(self, override_url_template=None):
         self.register_user("room_admin", "pass", admin=True)
@@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         server_and_media_id_2 = mxc_2[6:]
 
         # Test that we cannot download any of the media anymore
-        request, channel = self.make_request(
-            "GET",
-            server_and_media_id_1,
-            shorthand=False,
-            access_token=non_admin_user_tok,
-        )
-        request.render(self.download_resource)
-        self.pump(1.0)
-
-        # Should be quarantined
-        self.assertEqual(
-            404,
-            int(channel.code),
-            msg=(
-                "Expected to receive a 404 on accessing quarantined media: %s"
-                % server_and_media_id_1
-            ),
-        )
-
-        request, channel = self.make_request(
-            "GET",
-            server_and_media_id_2,
-            shorthand=False,
-            access_token=non_admin_user_tok,
-        )
-        request.render(self.download_resource)
-        self.pump(1.0)
-
-        # Should be quarantined
-        self.assertEqual(
-            404,
-            int(channel.code),
-            msg=(
-                "Expected to receive a 404 on accessing quarantined media: %s"
-                % server_and_media_id_2
-            ),
-        )
+        self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+        self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
 
-    def test_quaraantine_all_media_in_room_deprecated_api_path(self):
+    def test_quarantine_all_media_in_room_deprecated_api_path(self):
         # Perform the above test with the deprecated API path
         self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
 
@@ -449,25 +415,52 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         )
 
         # Attempt to access each piece of media
+        self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+        self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
+
+    def test_cannot_quarantine_safe_media(self):
+        self.register_user("user_admin", "pass", admin=True)
+        admin_user_tok = self.login("user_admin", "pass")
+
+        non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
+        non_admin_user_tok = self.login("user_nonadmin", "pass")
+
+        # Upload some media
+        response_1 = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=non_admin_user_tok
+        )
+        response_2 = self.helper.upload_media(
+            self.upload_resource, self.image_data, tok=non_admin_user_tok
+        )
+
+        # Extract media IDs
+        server_and_media_id_1 = response_1["content_uri"][6:]
+        server_and_media_id_2 = response_2["content_uri"][6:]
+
+        # Mark the second item as safe from quarantine.
+        _, media_id_2 = server_and_media_id_2.split("/")
+        self.get_success(self.store.mark_local_media_as_safe(media_id_2))
+
+        # Quarantine all media by this user
+        url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
+            non_admin_user
+        )
         request, channel = self.make_request(
-            "GET",
-            server_and_media_id_1,
-            shorthand=False,
-            access_token=non_admin_user_tok,
+            "POST", url.encode("ascii"), access_token=admin_user_tok,
         )
-        request.render(self.download_resource)
+        self.render(request)
         self.pump(1.0)
-
-        # Should be quarantined
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(
-            404,
-            int(channel.code),
-            msg=(
-                "Expected to receive a 404 on accessing quarantined media: %s"
-                % server_and_media_id_1,
-            ),
+            json.loads(channel.result["body"].decode("utf-8")),
+            {"num_quarantined": 1},
+            "Expected 1 quarantined item",
         )
 
+        # Attempt to access each piece of media, the first should fail, the
+        # second should succeed.
+        self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
+
         # Attempt to access each piece of media
         request, channel = self.make_request(
             "GET",
@@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         request.render(self.download_resource)
         self.pump(1.0)
 
-        # Should be quarantined
+        # Shouldn't be quarantined
         self.assertEqual(
-            404,
+            200,
             int(channel.code),
             msg=(
-                "Expected to receive a 404 on accessing quarantined media: %s"
+                "Expected to receive a 200 on accessing not-quarantined media: %s"
                 % server_and_media_id_2
             ),
         )
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 54cd24bf64..408c568a27 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1,1007 +1,1500 @@
-# -*- coding: utf-8 -*-

-# Copyright 2020 Dirk Klimpel

-#

-# Licensed under the Apache License, Version 2.0 (the "License");

-# you may not use this file except in compliance with the License.

-# You may obtain a copy of the License at

-#

-#     http://www.apache.org/licenses/LICENSE-2.0

-#

-# Unless required by applicable law or agreed to in writing, software

-# distributed under the License is distributed on an "AS IS" BASIS,

-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

-# See the License for the specific language governing permissions and

-# limitations under the License.

-

-import json

-import urllib.parse

-from typing import List, Optional

-

-from mock import Mock

-

-import synapse.rest.admin

-from synapse.api.errors import Codes

-from synapse.rest.client.v1 import directory, events, login, room

-

-from tests import unittest

-

-"""Tests admin REST events for /rooms paths."""

-

-

-class ShutdownRoomTestCase(unittest.HomeserverTestCase):

-    servlets = [

-        synapse.rest.admin.register_servlets_for_client_rest_resource,

-        login.register_servlets,

-        events.register_servlets,

-        room.register_servlets,

-        room.register_deprecated_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.event_creation_handler = hs.get_event_creation_handler()

-        hs.config.user_consent_version = "1"

-

-        consent_uri_builder = Mock()

-        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"

-        self.event_creation_handler._consent_uri_builder = consent_uri_builder

-

-        self.store = hs.get_datastore()

-

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-        self.other_user = self.register_user("user", "pass")

-        self.other_user_token = self.login("user", "pass")

-

-        # Mark the admin user as having consented

-        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))

-

-    def test_shutdown_room_consent(self):

-        """Test that we can shutdown rooms with local users who have not

-        yet accepted the privacy policy. This used to fail when we tried to

-        force part the user from the old room.

-        """

-        self.event_creation_handler._block_events_without_consent_error = None

-

-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)

-

-        # Assert one user in room

-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))

-        self.assertEqual([self.other_user], users_in_room)

-

-        # Enable require consent to send events

-        self.event_creation_handler._block_events_without_consent_error = "Error"

-

-        # Assert that the user is getting consent error

-        self.helper.send(

-            room_id, body="foo", tok=self.other_user_token, expect_code=403

-        )

-

-        # Test that the admin can still send shutdown

-        url = "admin/shutdown_room/" + room_id

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            json.dumps({"new_room_user_id": self.admin_user}),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Assert there is now no longer anyone in the room

-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))

-        self.assertEqual([], users_in_room)

-

-    def test_shutdown_room_block_peek(self):

-        """Test that a world_readable room can no longer be peeked into after

-        it has been shut down.

-        """

-

-        self.event_creation_handler._block_events_without_consent_error = None

-

-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)

-

-        # Enable world readable

-        url = "rooms/%s/state/m.room.history_visibility" % (room_id,)

-        request, channel = self.make_request(

-            "PUT",

-            url.encode("ascii"),

-            json.dumps({"history_visibility": "world_readable"}),

-            access_token=self.other_user_token,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Test that the admin can still send shutdown

-        url = "admin/shutdown_room/" + room_id

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            json.dumps({"new_room_user_id": self.admin_user}),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Assert we can no longer peek into the room

-        self._assert_peek(room_id, expect_code=403)

-

-    def _assert_peek(self, room_id, expect_code):

-        """Assert that the admin user can (or cannot) peek into the room.

-        """

-

-        url = "rooms/%s/initialSync" % (room_id,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok

-        )

-        self.render(request)

-        self.assertEqual(

-            expect_code, int(channel.result["code"]), msg=channel.result["body"]

-        )

-

-        url = "events?timeout=0&room_id=" + room_id

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok

-        )

-        self.render(request)

-        self.assertEqual(

-            expect_code, int(channel.result["code"]), msg=channel.result["body"]

-        )

-

-

-class PurgeRoomTestCase(unittest.HomeserverTestCase):

-    """Test /purge_room admin API.

-    """

-

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        login.register_servlets,

-        room.register_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.store = hs.get_datastore()

-

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-    def test_purge_room(self):

-        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        # All users have to have left the room.

-        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)

-

-        url = "/_synapse/admin/v1/purge_room"

-        request, channel = self.make_request(

-            "POST",

-            url.encode("ascii"),

-            {"room_id": room_id},

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Test that the following tables have been purged of all rows related to the room.

-        for table in (

-            "current_state_events",

-            "event_backward_extremities",

-            "event_forward_extremities",

-            "event_json",

-            "event_push_actions",

-            "event_search",

-            "events",

-            "group_rooms",

-            "public_room_list_stream",

-            "receipts_graph",

-            "receipts_linearized",

-            "room_aliases",

-            "room_depth",

-            "room_memberships",

-            "room_stats_state",

-            "room_stats_current",

-            "room_stats_historical",

-            "room_stats_earliest_token",

-            "rooms",

-            "stream_ordering_to_exterm",

-            "users_in_public_rooms",

-            "users_who_share_private_rooms",

-            "appservice_room_list",

-            "e2e_room_keys",

-            "event_push_summary",

-            "pusher_throttle",

-            "group_summary_rooms",

-            "local_invites",

-            "room_account_data",

-            "room_tags",

-            # "state_groups",  # Current impl leaves orphaned state groups around.

-            "state_groups_state",

-        ):

-            count = self.get_success(

-                self.store.db.simple_select_one_onecol(

-                    table=table,

-                    keyvalues={"room_id": room_id},

-                    retcol="COUNT(*)",

-                    desc="test_purge_room",

-                )

-            )

-

-            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))

-

-

-class RoomTestCase(unittest.HomeserverTestCase):

-    """Test /room admin API.

-    """

-

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        login.register_servlets,

-        room.register_servlets,

-        directory.register_servlets,

-    ]

-

-    def prepare(self, reactor, clock, hs):

-        self.store = hs.get_datastore()

-

-        # Create user

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-    def test_list_rooms(self):

-        """Test that we can list rooms"""

-        # Create 3 test rooms

-        total_rooms = 3

-        room_ids = []

-        for x in range(total_rooms):

-            room_id = self.helper.create_room_as(

-                self.admin_user, tok=self.admin_user_tok

-            )

-            room_ids.append(room_id)

-

-        # Request the list of rooms

-        url = "/_synapse/admin/v1/rooms"

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        # Check request completed successfully

-        self.assertEqual(200, int(channel.code), msg=channel.json_body)

-

-        # Check that response json body contains a "rooms" key

-        self.assertTrue(

-            "rooms" in channel.json_body,

-            msg="Response body does not " "contain a 'rooms' key",

-        )

-

-        # Check that 3 rooms were returned

-        self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)

-

-        # Check their room_ids match

-        returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]

-        self.assertEqual(room_ids, returned_room_ids)

-

-        # Check that all fields are available

-        for r in channel.json_body["rooms"]:

-            self.assertIn("name", r)

-            self.assertIn("canonical_alias", r)

-            self.assertIn("joined_members", r)

-            self.assertIn("joined_local_members", r)

-            self.assertIn("version", r)

-            self.assertIn("creator", r)

-            self.assertIn("encryption", r)

-            self.assertIn("federatable", r)

-            self.assertIn("public", r)

-            self.assertIn("join_rules", r)

-            self.assertIn("guest_access", r)

-            self.assertIn("history_visibility", r)

-            self.assertIn("state_events", r)

-

-        # Check that the correct number of total rooms was returned

-        self.assertEqual(channel.json_body["total_rooms"], total_rooms)

-

-        # Check that the offset is correct

-        # Should be 0 as we aren't paginating

-        self.assertEqual(channel.json_body["offset"], 0)

-

-        # Check that the prev_batch parameter is not present

-        self.assertNotIn("prev_batch", channel.json_body)

-

-        # We shouldn't receive a next token here as there's no further rooms to show

-        self.assertNotIn("next_batch", channel.json_body)

-

-    def test_list_rooms_pagination(self):

-        """Test that we can get a full list of rooms through pagination"""

-        # Create 5 test rooms

-        total_rooms = 5

-        room_ids = []

-        for x in range(total_rooms):

-            room_id = self.helper.create_room_as(

-                self.admin_user, tok=self.admin_user_tok

-            )

-            room_ids.append(room_id)

-

-        # Set the name of the rooms so we get a consistent returned ordering

-        for idx, room_id in enumerate(room_ids):

-            self.helper.send_state(

-                room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,

-            )

-

-        # Request the list of rooms

-        returned_room_ids = []

-        start = 0

-        limit = 2

-

-        run_count = 0

-        should_repeat = True

-        while should_repeat:

-            run_count += 1

-

-            url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (

-                start,

-                limit,

-                "name",

-            )

-            request, channel = self.make_request(

-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(

-                200, int(channel.result["code"]), msg=channel.result["body"]

-            )

-

-            self.assertTrue("rooms" in channel.json_body)

-            for r in channel.json_body["rooms"]:

-                returned_room_ids.append(r["room_id"])

-

-            # Check that the correct number of total rooms was returned

-            self.assertEqual(channel.json_body["total_rooms"], total_rooms)

-

-            # Check that the offset is correct

-            # We're only getting 2 rooms each page, so should be 2 * last run_count

-            self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))

-

-            if run_count > 1:

-                # Check the value of prev_batch is correct

-                self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))

-

-            if "next_batch" not in channel.json_body:

-                # We have reached the end of the list

-                should_repeat = False

-            else:

-                # Make another query with an updated start value

-                start = channel.json_body["next_batch"]

-

-        # We should've queried the endpoint 3 times

-        self.assertEqual(

-            run_count,

-            3,

-            msg="Should've queried 3 times for 5 rooms with limit 2 per query",

-        )

-

-        # Check that we received all of the room ids

-        self.assertEqual(room_ids, returned_room_ids)

-

-        url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-    def test_correct_room_attributes(self):

-        """Test the correct attributes for a room are returned"""

-        # Create a test room

-        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        test_alias = "#test:test"

-        test_room_name = "something"

-

-        # Have another user join the room

-        user_2 = self.register_user("user4", "pass")

-        user_tok_2 = self.login("user4", "pass")

-        self.helper.join(room_id, user_2, tok=user_tok_2)

-

-        # Create a new alias to this room

-        url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)

-        request, channel = self.make_request(

-            "PUT",

-            url.encode("ascii"),

-            {"room_id": room_id},

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Set this new alias as the canonical alias for this room

-        self.helper.send_state(

-            room_id,

-            "m.room.aliases",

-            {"aliases": [test_alias]},

-            tok=self.admin_user_tok,

-            state_key="test",

-        )

-        self.helper.send_state(

-            room_id,

-            "m.room.canonical_alias",

-            {"alias": test_alias},

-            tok=self.admin_user_tok,

-        )

-

-        # Set a name for the room

-        self.helper.send_state(

-            room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,

-        )

-

-        # Request the list of rooms

-        url = "/_synapse/admin/v1/rooms"

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-

-        # Check that rooms were returned

-        self.assertTrue("rooms" in channel.json_body)

-        rooms = channel.json_body["rooms"]

-

-        # Check that only one room was returned

-        self.assertEqual(len(rooms), 1)

-

-        # And that the value of the total_rooms key was correct

-        self.assertEqual(channel.json_body["total_rooms"], 1)

-

-        # Check that the offset is correct

-        # We're not paginating, so should be 0

-        self.assertEqual(channel.json_body["offset"], 0)

-

-        # Check that there is no `prev_batch`

-        self.assertNotIn("prev_batch", channel.json_body)

-

-        # Check that there is no `next_batch`

-        self.assertNotIn("next_batch", channel.json_body)

-

-        # Check that all provided attributes are set

-        r = rooms[0]

-        self.assertEqual(room_id, r["room_id"])

-        self.assertEqual(test_room_name, r["name"])

-        self.assertEqual(test_alias, r["canonical_alias"])

-

-    def test_room_list_sort_order(self):

-        """Test room list sort ordering. alphabetical name versus number of members,

-        reversing the order, etc.

-        """

-

-        def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):

-            # Create a new alias to this room

-            url = "/_matrix/client/r0/directory/room/%s" % (

-                urllib.parse.quote(test_alias),

-            )

-            request, channel = self.make_request(

-                "PUT",

-                url.encode("ascii"),

-                {"room_id": room_id},

-                access_token=admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(

-                200, int(channel.result["code"]), msg=channel.result["body"]

-            )

-

-            # Set this new alias as the canonical alias for this room

-            self.helper.send_state(

-                room_id,

-                "m.room.aliases",

-                {"aliases": [test_alias]},

-                tok=admin_user_tok,

-                state_key="test",

-            )

-            self.helper.send_state(

-                room_id,

-                "m.room.canonical_alias",

-                {"alias": test_alias},

-                tok=admin_user_tok,

-            )

-

-        def _order_test(

-            order_type: str, expected_room_list: List[str], reverse: bool = False,

-        ):

-            """Request the list of rooms in a certain order. Assert that order is what

-            we expect

-

-            Args:

-                order_type: The type of ordering to give the server

-                expected_room_list: The list of room_ids in the order we expect to get

-                    back from the server

-            """

-            # Request the list of rooms in the given order

-            url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)

-            if reverse:

-                url += "&dir=b"

-            request, channel = self.make_request(

-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(200, channel.code, msg=channel.json_body)

-

-            # Check that rooms were returned

-            self.assertTrue("rooms" in channel.json_body)

-            rooms = channel.json_body["rooms"]

-

-            # Check for the correct total_rooms value

-            self.assertEqual(channel.json_body["total_rooms"], 3)

-

-            # Check that the offset is correct

-            # We're not paginating, so should be 0

-            self.assertEqual(channel.json_body["offset"], 0)

-

-            # Check that there is no `prev_batch`

-            self.assertNotIn("prev_batch", channel.json_body)

-

-            # Check that there is no `next_batch`

-            self.assertNotIn("next_batch", channel.json_body)

-

-            # Check that rooms were returned in alphabetical order

-            returned_order = [r["room_id"] for r in rooms]

-            self.assertListEqual(expected_room_list, returned_order)  # order is checked

-

-        # Create 3 test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C

-        self.helper.send_state(

-            room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,

-        )

-

-        # Set room canonical room aliases

-        _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)

-        _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)

-        _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)

-

-        # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3

-        user_1 = self.register_user("bob1", "pass")

-        user_1_tok = self.login("bob1", "pass")

-        self.helper.join(room_id_2, user_1, tok=user_1_tok)

-

-        user_2 = self.register_user("bob2", "pass")

-        user_2_tok = self.login("bob2", "pass")

-        self.helper.join(room_id_3, user_2, tok=user_2_tok)

-

-        user_3 = self.register_user("bob3", "pass")

-        user_3_tok = self.login("bob3", "pass")

-        self.helper.join(room_id_3, user_3, tok=user_3_tok)

-

-        # Test different sort orders, with forward and reverse directions

-        _order_test("name", [room_id_1, room_id_2, room_id_3])

-        _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)

-

-        _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])

-        _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)

-

-        _order_test("joined_members", [room_id_3, room_id_2, room_id_1])

-        _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])

-        _order_test(

-            "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True

-        )

-

-        _order_test("version", [room_id_1, room_id_2, room_id_3])

-        _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("creator", [room_id_1, room_id_2, room_id_3])

-        _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("encryption", [room_id_1, room_id_2, room_id_3])

-        _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("federatable", [room_id_1, room_id_2, room_id_3])

-        _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("public", [room_id_1, room_id_2, room_id_3])

-        # Different sort order of SQlite and PostreSQL

-        # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)

-

-        _order_test("join_rules", [room_id_1, room_id_2, room_id_3])

-        _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("guest_access", [room_id_1, room_id_2, room_id_3])

-        _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-        _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])

-        _order_test(

-            "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True

-        )

-

-        _order_test("state_events", [room_id_3, room_id_2, room_id_1])

-        _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)

-

-    def test_search_term(self):

-        """Test that searching for a room works correctly"""

-        # Create two test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        room_name_1 = "something"

-        room_name_2 = "else"

-

-        # Set the name for each room

-        self.helper.send_state(

-            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,

-        )

-

-        def _search_test(

-            expected_room_id: Optional[str],

-            search_term: str,

-            expected_http_code: int = 200,

-        ):

-            """Search for a room and check that the returned room's id is a match

-

-            Args:

-                expected_room_id: The room_id expected to be returned by the API. Set

-                    to None to expect zero results for the search

-                search_term: The term to search for room names with

-                expected_http_code: The expected http code for the request

-            """

-            url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)

-            request, channel = self.make_request(

-                "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-            )

-            self.render(request)

-            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)

-

-            if expected_http_code != 200:

-                return

-

-            # Check that rooms were returned

-            self.assertTrue("rooms" in channel.json_body)

-            rooms = channel.json_body["rooms"]

-

-            # Check that the expected number of rooms were returned

-            expected_room_count = 1 if expected_room_id else 0

-            self.assertEqual(len(rooms), expected_room_count)

-            self.assertEqual(channel.json_body["total_rooms"], expected_room_count)

-

-            # Check that the offset is correct

-            # We're not paginating, so should be 0

-            self.assertEqual(channel.json_body["offset"], 0)

-

-            # Check that there is no `prev_batch`

-            self.assertNotIn("prev_batch", channel.json_body)

-

-            # Check that there is no `next_batch`

-            self.assertNotIn("next_batch", channel.json_body)

-

-            if expected_room_id:

-                # Check that the first returned room id is correct

-                r = rooms[0]

-                self.assertEqual(expected_room_id, r["room_id"])

-

-        # Perform search tests

-        _search_test(room_id_1, "something")

-        _search_test(room_id_1, "thing")

-

-        _search_test(room_id_2, "else")

-        _search_test(room_id_2, "se")

-

-        _search_test(None, "foo")

-        _search_test(None, "bar")

-        _search_test(None, "", expected_http_code=400)

-

-    def test_single_room(self):

-        """Test that a single room can be requested correctly"""

-        # Create two test rooms

-        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

-

-        room_name_1 = "something"

-        room_name_2 = "else"

-

-        # Set the name for each room

-        self.helper.send_state(

-            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,

-        )

-        self.helper.send_state(

-            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,

-        )

-

-        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)

-        request, channel = self.make_request(

-            "GET", url.encode("ascii"), access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, channel.code, msg=channel.json_body)

-

-        self.assertIn("room_id", channel.json_body)

-        self.assertIn("name", channel.json_body)

-        self.assertIn("canonical_alias", channel.json_body)

-        self.assertIn("joined_members", channel.json_body)

-        self.assertIn("joined_local_members", channel.json_body)

-        self.assertIn("version", channel.json_body)

-        self.assertIn("creator", channel.json_body)

-        self.assertIn("encryption", channel.json_body)

-        self.assertIn("federatable", channel.json_body)

-        self.assertIn("public", channel.json_body)

-        self.assertIn("join_rules", channel.json_body)

-        self.assertIn("guest_access", channel.json_body)

-        self.assertIn("history_visibility", channel.json_body)

-        self.assertIn("state_events", channel.json_body)

-

-        self.assertEqual(room_id_1, channel.json_body["room_id"])

-

-

-class JoinAliasRoomTestCase(unittest.HomeserverTestCase):

-

-    servlets = [

-        synapse.rest.admin.register_servlets,

-        room.register_servlets,

-        login.register_servlets,

-    ]

-

-    def prepare(self, reactor, clock, homeserver):

-        self.admin_user = self.register_user("admin", "pass", admin=True)

-        self.admin_user_tok = self.login("admin", "pass")

-

-        self.creator = self.register_user("creator", "test")

-        self.creator_tok = self.login("creator", "test")

-

-        self.second_user_id = self.register_user("second", "test")

-        self.second_tok = self.login("second", "test")

-

-        self.public_room_id = self.helper.create_room_as(

-            self.creator, tok=self.creator_tok, is_public=True

-        )

-        self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)

-

-    def test_requester_is_no_admin(self):

-        """

-        If the user is not a server admin, an error 403 is returned.

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.second_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

-

-    def test_invalid_parameter(self):

-        """

-        If a parameter is missing, return an error

-        """

-        body = json.dumps({"unknown_parameter": "@unknown:test"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])

-

-    def test_local_user_does_not_exist(self):

-        """

-        Tests that a lookup for a user that does not exist returns a 404

-        """

-        body = json.dumps({"user_id": "@unknown:test"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])

-

-    def test_remote_user(self):

-        """

-        Check that only local user can join rooms.

-        """

-        body = json.dumps({"user_id": "@not:exist.bla"})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(

-            "This endpoint can only be used with local users",

-            channel.json_body["error"],

-        )

-

-    def test_room_does_not_exist(self):

-        """

-        Check that unknown rooms/server return error 404.

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-        url = "/_synapse/admin/v1/join/!unknown:test"

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual("No known servers", channel.json_body["error"])

-

-    def test_room_is_not_valid(self):

-        """

-        Check that invalid room names, return an error 400.

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-        url = "/_synapse/admin/v1/join/invalidroom"

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(

-            "invalidroom was not legal room ID or room alias",

-            channel.json_body["error"],

-        )

-

-    def test_join_public_room(self):

-        """

-        Test joining a local user to a public room with "JoinRules.PUBLIC"

-        """

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            self.url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(self.public_room_id, channel.json_body["room_id"])

-

-        # Validate if user is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])

-

-    def test_join_private_room_if_not_member(self):

-        """

-        Test joining a local user to a private room with "JoinRules.INVITE"

-        when server admin is not member of this room.

-        """

-        private_room_id = self.helper.create_room_as(

-            self.creator, tok=self.creator_tok, is_public=False

-        )

-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

-

-    def test_join_private_room_if_member(self):

-        """

-        Test joining a local user to a private room with "JoinRules.INVITE",

-        when server admin is member of this room.

-        """

-        private_room_id = self.helper.create_room_as(

-            self.creator, tok=self.creator_tok, is_public=False

-        )

-        self.helper.invite(

-            room=private_room_id,

-            src=self.creator,

-            targ=self.admin_user,

-            tok=self.creator_tok,

-        )

-        self.helper.join(

-            room=private_room_id, user=self.admin_user, tok=self.admin_user_tok

-        )

-

-        # Validate if server admin is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])

-

-        # Join user to room.

-

-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["room_id"])

-

-        # Validate if user is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])

-

-    def test_join_private_room_if_owner(self):

-        """

-        Test joining a local user to a private room with "JoinRules.INVITE",

-        when server admin is owner of this room.

-        """

-        private_room_id = self.helper.create_room_as(

-            self.admin_user, tok=self.admin_user_tok, is_public=False

-        )

-        url = "/_synapse/admin/v1/join/{}".format(private_room_id)

-        body = json.dumps({"user_id": self.second_user_id})

-

-        request, channel = self.make_request(

-            "POST",

-            url,

-            content=body.encode(encoding="utf_8"),

-            access_token=self.admin_user_tok,

-        )

-        self.render(request)

-

-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["room_id"])

-

-        # Validate if user is a member of the room

-

-        request, channel = self.make_request(

-            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,

-        )

-        self.render(request)

-        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])

-        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])

+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import urllib.parse
+from typing import List, Optional
+
+from mock import Mock
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client.v1 import directory, events, login, room
+
+from tests import unittest
+
+"""Tests admin REST events for /rooms paths."""
+
+
+class ShutdownRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_token = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shutdown rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Enable require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user is getting consent error
+        self.helper.send(
+            room_id, body="foo", tok=self.other_user_token, expect_code=403
+        )
+
+        # Test that the admin can still send shutdown
+        url = "admin/shutdown_room/" + room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Assert there is now no longer anyone in the room
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        """
+
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
+
+        # Enable world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            json.dumps({"history_visibility": "world_readable"}),
+            access_token=self.other_user_token,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that the admin can still send shutdown
+        url = "admin/shutdown_room/" + room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(room_id, expect_code=403)
+
+    def _assert_peek(self, room_id, expect_code):
+        """Assert that the admin user can (or cannot) peek into the room.
+        """
+
+        url = "rooms/%s/initialSync" % (room_id,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+        url = "events?timeout=0&room_id=" + room_id
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+
+class DeleteRoomTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        events.register_servlets,
+        room.register_servlets,
+        room.register_deprecated_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.event_creation_handler = hs.get_event_creation_handler()
+        hs.config.user_consent_version = "1"
+
+        consent_uri_builder = Mock()
+        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
+        self.event_creation_handler._consent_uri_builder = consent_uri_builder
+
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        # Mark the admin user as having consented
+        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
+
+        self.room_id = self.helper.create_room_as(
+            self.other_user, tok=self.other_user_tok
+        )
+        self.url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+
+        request, channel = self.make_request(
+            "POST", self.url, json.dumps({}), access_token=self.other_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_room_does_not_exist(self):
+        """
+        Check that unknown rooms/server return error 404.
+        """
+        url = "/_synapse/admin/v1/rooms/!unknown:test/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_room_is_not_valid(self):
+        """
+        Check that invalid room names, return an error 400.
+        """
+        url = "/_synapse/admin/v1/rooms/invalidroom/delete"
+
+        request, channel = self.make_request(
+            "POST", url, json.dumps({}), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "invalidroom is not a legal room ID", channel.json_body["error"],
+        )
+
+    def test_new_room_user_does_not_exist(self):
+        """
+        Tests that the user ID must be from local server but it does not have to exist.
+        """
+        body = json.dumps({"new_room_user_id": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("kicked_users", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+    def test_new_room_user_is_not_local(self):
+        """
+        Check that only local users can create new room to move members.
+        """
+        body = json.dumps({"new_room_user_id": "@not:exist.bla"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "User must be our own: @not:exist.bla", channel.json_body["error"],
+        )
+
+    def test_block_is_not_bool(self):
+        """
+        If parameter `block` is not boolean, return an error
+        """
+        body = json.dumps({"block": "NotBool"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_is_not_bool(self):
+        """
+        If parameter `purge` is not boolean, return an error
+        """
+        body = json.dumps({"purge": "NotBool"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
+
+    def test_purge_room_and_block(self):
+        """Test to purge a room and block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": True, "purge": True})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=True)
+        self._has_no_members(self.room_id)
+
+    def test_purge_room_and_not_block(self):
+        """Test to purge a room and do not block it.
+        Members will not be moved to a new room and will not receive a message.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": False, "purge": True})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_block_room_and_not_purge(self):
+        """Test to block a room without purging it.
+        Members will not be moved to a new room and will not receive a message.
+        The room will not be purged.
+        """
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Test that room is not blocked
+        self._is_blocked(self.room_id, expect=False)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        body = json.dumps({"block": False, "purge": False})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url.encode("ascii"),
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(None, channel.json_body["new_room_id"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+        self._is_blocked(self.room_id, expect=False)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_consent(self):
+        """Test that we can shutdown rooms with local users who have not
+        yet accepted the privacy policy. This used to fail when we tried to
+        force part the user from the old room.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Assert one user in room
+        users_in_room = self.get_success(self.store.get_users_in_room(self.room_id))
+        self.assertEqual([self.other_user], users_in_room)
+
+        # Enable require consent to send events
+        self.event_creation_handler._block_events_without_consent_error = "Error"
+
+        # Assert that the user is getting consent error
+        self.helper.send(
+            self.room_id, body="foo", tok=self.other_user_tok, expect_code=403
+        )
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["new_room_id"], user_id=self.other_user
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+    def test_shutdown_room_block_peek(self):
+        """Test that a world_readable room can no longer be peeked into after
+        it has been shut down.
+        Members will be moved to a new room and will receive a message.
+        """
+        self.event_creation_handler._block_events_without_consent_error = None
+
+        # Enable world readable
+        url = "rooms/%s/state/m.room.history_visibility" % (self.room_id,)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            json.dumps({"history_visibility": "world_readable"}),
+            access_token=self.other_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that room is not purged
+        with self.assertRaises(AssertionError):
+            self._is_purged(self.room_id)
+
+        # Assert one user in room
+        self._is_member(room_id=self.room_id, user_id=self.other_user)
+
+        # Test that the admin can still send shutdown
+        url = "/_synapse/admin/v1/rooms/%s/delete" % self.room_id
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            json.dumps({"new_room_user_id": self.admin_user}),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
+        self.assertIn("new_room_id", channel.json_body)
+        self.assertIn("failed_to_kick_users", channel.json_body)
+        self.assertIn("local_aliases", channel.json_body)
+
+        # Test that member has moved to new room
+        self._is_member(
+            room_id=channel.json_body["new_room_id"], user_id=self.other_user
+        )
+
+        self._is_purged(self.room_id)
+        self._has_no_members(self.room_id)
+
+        # Assert we can no longer peek into the room
+        self._assert_peek(self.room_id, expect_code=403)
+
+    def _is_blocked(self, room_id, expect=True):
+        """Assert that the room is blocked or not
+        """
+        d = self.store.is_room_blocked(room_id)
+        if expect:
+            self.assertTrue(self.get_success(d))
+        else:
+            self.assertIsNone(self.get_success(d))
+
+    def _has_no_members(self, room_id):
+        """Assert there is now no longer anyone in the room
+        """
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertEqual([], users_in_room)
+
+    def _is_member(self, room_id, user_id):
+        """Test that user is member of the room
+        """
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertIn(user_id, users_in_room)
+
+    def _is_purged(self, room_id):
+        """Test that the following tables have been purged of all rows related to the room.
+        """
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "local_invites",
+            "room_account_data",
+            "room_tags",
+            # "state_groups",  # Current impl leaves orphaned state groups around.
+            "state_groups_state",
+        ):
+            count = self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+    def _assert_peek(self, room_id, expect_code):
+        """Assert that the admin user can (or cannot) peek into the room.
+        """
+
+        url = "rooms/%s/initialSync" % (room_id,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+        url = "events?timeout=0&room_id=" + room_id
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok
+        )
+        self.render(request)
+        self.assertEqual(
+            expect_code, int(channel.result["code"]), msg=channel.result["body"]
+        )
+
+
+class PurgeRoomTestCase(unittest.HomeserverTestCase):
+    """Test /purge_room admin API.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+    def test_purge_room(self):
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # All users have to have left the room.
+        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
+
+        url = "/_synapse/admin/v1/purge_room"
+        request, channel = self.make_request(
+            "POST",
+            url.encode("ascii"),
+            {"room_id": room_id},
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Test that the following tables have been purged of all rows related to the room.
+        for table in (
+            "current_state_events",
+            "event_backward_extremities",
+            "event_forward_extremities",
+            "event_json",
+            "event_push_actions",
+            "event_search",
+            "events",
+            "group_rooms",
+            "public_room_list_stream",
+            "receipts_graph",
+            "receipts_linearized",
+            "room_aliases",
+            "room_depth",
+            "room_memberships",
+            "room_stats_state",
+            "room_stats_current",
+            "room_stats_historical",
+            "room_stats_earliest_token",
+            "rooms",
+            "stream_ordering_to_exterm",
+            "users_in_public_rooms",
+            "users_who_share_private_rooms",
+            "appservice_room_list",
+            "e2e_room_keys",
+            "event_push_summary",
+            "pusher_throttle",
+            "group_summary_rooms",
+            "room_account_data",
+            "room_tags",
+            # "state_groups",  # Current impl leaves orphaned state groups around.
+            "state_groups_state",
+        ):
+            count = self.get_success(
+                self.store.db_pool.simple_select_one_onecol(
+                    table=table,
+                    keyvalues={"room_id": room_id},
+                    retcol="COUNT(*)",
+                    desc="test_purge_room",
+                )
+            )
+
+            self.assertEqual(count, 0, msg="Rows not purged in {}".format(table))
+
+
+class RoomTestCase(unittest.HomeserverTestCase):
+    """Test /room admin API.
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        directory.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        # Create user
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+    def test_list_rooms(self):
+        """Test that we can list rooms"""
+        # Create 3 test rooms
+        total_rooms = 3
+        room_ids = []
+        for x in range(total_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            room_ids.append(room_id)
+
+        # Request the list of rooms
+        url = "/_synapse/admin/v1/rooms"
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        # Check request completed successfully
+        self.assertEqual(200, int(channel.code), msg=channel.json_body)
+
+        # Check that response json body contains a "rooms" key
+        self.assertTrue(
+            "rooms" in channel.json_body,
+            msg="Response body does not " "contain a 'rooms' key",
+        )
+
+        # Check that 3 rooms were returned
+        self.assertEqual(3, len(channel.json_body["rooms"]), msg=channel.json_body)
+
+        # Check their room_ids match
+        returned_room_ids = [room["room_id"] for room in channel.json_body["rooms"]]
+        self.assertEqual(room_ids, returned_room_ids)
+
+        # Check that all fields are available
+        for r in channel.json_body["rooms"]:
+            self.assertIn("name", r)
+            self.assertIn("canonical_alias", r)
+            self.assertIn("joined_members", r)
+            self.assertIn("joined_local_members", r)
+            self.assertIn("version", r)
+            self.assertIn("creator", r)
+            self.assertIn("encryption", r)
+            self.assertIn("federatable", r)
+            self.assertIn("public", r)
+            self.assertIn("join_rules", r)
+            self.assertIn("guest_access", r)
+            self.assertIn("history_visibility", r)
+            self.assertIn("state_events", r)
+
+        # Check that the correct number of total rooms was returned
+        self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+        # Check that the offset is correct
+        # Should be 0 as we aren't paginating
+        self.assertEqual(channel.json_body["offset"], 0)
+
+        # Check that the prev_batch parameter is not present
+        self.assertNotIn("prev_batch", channel.json_body)
+
+        # We shouldn't receive a next token here as there's no further rooms to show
+        self.assertNotIn("next_batch", channel.json_body)
+
+    def test_list_rooms_pagination(self):
+        """Test that we can get a full list of rooms through pagination"""
+        # Create 5 test rooms
+        total_rooms = 5
+        room_ids = []
+        for x in range(total_rooms):
+            room_id = self.helper.create_room_as(
+                self.admin_user, tok=self.admin_user_tok
+            )
+            room_ids.append(room_id)
+
+        # Set the name of the rooms so we get a consistent returned ordering
+        for idx, room_id in enumerate(room_ids):
+            self.helper.send_state(
+                room_id, "m.room.name", {"name": str(idx)}, tok=self.admin_user_tok,
+            )
+
+        # Request the list of rooms
+        returned_room_ids = []
+        start = 0
+        limit = 2
+
+        run_count = 0
+        should_repeat = True
+        while should_repeat:
+            run_count += 1
+
+            url = "/_synapse/admin/v1/rooms?from=%d&limit=%d&order_by=%s" % (
+                start,
+                limit,
+                "name",
+            )
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(
+                200, int(channel.result["code"]), msg=channel.result["body"]
+            )
+
+            self.assertTrue("rooms" in channel.json_body)
+            for r in channel.json_body["rooms"]:
+                returned_room_ids.append(r["room_id"])
+
+            # Check that the correct number of total rooms was returned
+            self.assertEqual(channel.json_body["total_rooms"], total_rooms)
+
+            # Check that the offset is correct
+            # We're only getting 2 rooms each page, so should be 2 * last run_count
+            self.assertEqual(channel.json_body["offset"], 2 * (run_count - 1))
+
+            if run_count > 1:
+                # Check the value of prev_batch is correct
+                self.assertEqual(channel.json_body["prev_batch"], 2 * (run_count - 2))
+
+            if "next_batch" not in channel.json_body:
+                # We have reached the end of the list
+                should_repeat = False
+            else:
+                # Make another query with an updated start value
+                start = channel.json_body["next_batch"]
+
+        # We should've queried the endpoint 3 times
+        self.assertEqual(
+            run_count,
+            3,
+            msg="Should've queried 3 times for 5 rooms with limit 2 per query",
+        )
+
+        # Check that we received all of the room ids
+        self.assertEqual(room_ids, returned_room_ids)
+
+        url = "/_synapse/admin/v1/rooms?from=%d&limit=%d" % (start, limit)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    def test_correct_room_attributes(self):
+        """Test the correct attributes for a room are returned"""
+        # Create a test room
+        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        test_alias = "#test:test"
+        test_room_name = "something"
+
+        # Have another user join the room
+        user_2 = self.register_user("user4", "pass")
+        user_tok_2 = self.login("user4", "pass")
+        self.helper.join(room_id, user_2, tok=user_tok_2)
+
+        # Create a new alias to this room
+        url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
+        request, channel = self.make_request(
+            "PUT",
+            url.encode("ascii"),
+            {"room_id": room_id},
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Set this new alias as the canonical alias for this room
+        self.helper.send_state(
+            room_id,
+            "m.room.aliases",
+            {"aliases": [test_alias]},
+            tok=self.admin_user_tok,
+            state_key="test",
+        )
+        self.helper.send_state(
+            room_id,
+            "m.room.canonical_alias",
+            {"alias": test_alias},
+            tok=self.admin_user_tok,
+        )
+
+        # Set a name for the room
+        self.helper.send_state(
+            room_id, "m.room.name", {"name": test_room_name}, tok=self.admin_user_tok,
+        )
+
+        # Request the list of rooms
+        url = "/_synapse/admin/v1/rooms"
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Check that rooms were returned
+        self.assertTrue("rooms" in channel.json_body)
+        rooms = channel.json_body["rooms"]
+
+        # Check that only one room was returned
+        self.assertEqual(len(rooms), 1)
+
+        # And that the value of the total_rooms key was correct
+        self.assertEqual(channel.json_body["total_rooms"], 1)
+
+        # Check that the offset is correct
+        # We're not paginating, so should be 0
+        self.assertEqual(channel.json_body["offset"], 0)
+
+        # Check that there is no `prev_batch`
+        self.assertNotIn("prev_batch", channel.json_body)
+
+        # Check that there is no `next_batch`
+        self.assertNotIn("next_batch", channel.json_body)
+
+        # Check that all provided attributes are set
+        r = rooms[0]
+        self.assertEqual(room_id, r["room_id"])
+        self.assertEqual(test_room_name, r["name"])
+        self.assertEqual(test_alias, r["canonical_alias"])
+
+    def test_room_list_sort_order(self):
+        """Test room list sort ordering. alphabetical name versus number of members,
+        reversing the order, etc.
+        """
+
+        def _set_canonical_alias(room_id: str, test_alias: str, admin_user_tok: str):
+            # Create a new alias to this room
+            url = "/_matrix/client/r0/directory/room/%s" % (
+                urllib.parse.quote(test_alias),
+            )
+            request, channel = self.make_request(
+                "PUT",
+                url.encode("ascii"),
+                {"room_id": room_id},
+                access_token=admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(
+                200, int(channel.result["code"]), msg=channel.result["body"]
+            )
+
+            # Set this new alias as the canonical alias for this room
+            self.helper.send_state(
+                room_id,
+                "m.room.aliases",
+                {"aliases": [test_alias]},
+                tok=admin_user_tok,
+                state_key="test",
+            )
+            self.helper.send_state(
+                room_id,
+                "m.room.canonical_alias",
+                {"alias": test_alias},
+                tok=admin_user_tok,
+            )
+
+        def _order_test(
+            order_type: str, expected_room_list: List[str], reverse: bool = False,
+        ):
+            """Request the list of rooms in a certain order. Assert that order is what
+            we expect
+
+            Args:
+                order_type: The type of ordering to give the server
+                expected_room_list: The list of room_ids in the order we expect to get
+                    back from the server
+            """
+            # Request the list of rooms in the given order
+            url = "/_synapse/admin/v1/rooms?order_by=%s" % (order_type,)
+            if reverse:
+                url += "&dir=b"
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(200, channel.code, msg=channel.json_body)
+
+            # Check that rooms were returned
+            self.assertTrue("rooms" in channel.json_body)
+            rooms = channel.json_body["rooms"]
+
+            # Check for the correct total_rooms value
+            self.assertEqual(channel.json_body["total_rooms"], 3)
+
+            # Check that the offset is correct
+            # We're not paginating, so should be 0
+            self.assertEqual(channel.json_body["offset"], 0)
+
+            # Check that there is no `prev_batch`
+            self.assertNotIn("prev_batch", channel.json_body)
+
+            # Check that there is no `next_batch`
+            self.assertNotIn("next_batch", channel.json_body)
+
+            # Check that rooms were returned in alphabetical order
+            returned_order = [r["room_id"] for r in rooms]
+            self.assertListEqual(expected_room_list, returned_order)  # order is checked
+
+        # Create 3 test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": "A"}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": "B"}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_3, "m.room.name", {"name": "C"}, tok=self.admin_user_tok,
+        )
+
+        # Set room canonical room aliases
+        _set_canonical_alias(room_id_1, "#A_alias:test", self.admin_user_tok)
+        _set_canonical_alias(room_id_2, "#B_alias:test", self.admin_user_tok)
+        _set_canonical_alias(room_id_3, "#C_alias:test", self.admin_user_tok)
+
+        # Set room member size in the reverse order. room 1 -> 1 member, 2 -> 2, 3 -> 3
+        user_1 = self.register_user("bob1", "pass")
+        user_1_tok = self.login("bob1", "pass")
+        self.helper.join(room_id_2, user_1, tok=user_1_tok)
+
+        user_2 = self.register_user("bob2", "pass")
+        user_2_tok = self.login("bob2", "pass")
+        self.helper.join(room_id_3, user_2, tok=user_2_tok)
+
+        user_3 = self.register_user("bob3", "pass")
+        user_3_tok = self.login("bob3", "pass")
+        self.helper.join(room_id_3, user_3, tok=user_3_tok)
+
+        # Test different sort orders, with forward and reverse directions
+        _order_test("name", [room_id_1, room_id_2, room_id_3])
+        _order_test("name", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3])
+        _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("joined_members", [room_id_3, room_id_2, room_id_1])
+        _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1])
+        _order_test(
+            "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True
+        )
+
+        _order_test("version", [room_id_1, room_id_2, room_id_3])
+        _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("creator", [room_id_1, room_id_2, room_id_3])
+        _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("encryption", [room_id_1, room_id_2, room_id_3])
+        _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("federatable", [room_id_1, room_id_2, room_id_3])
+        _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("public", [room_id_1, room_id_2, room_id_3])
+        # Different sort order of SQlite and PostreSQL
+        # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True)
+
+        _order_test("join_rules", [room_id_1, room_id_2, room_id_3])
+        _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("guest_access", [room_id_1, room_id_2, room_id_3])
+        _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+        _order_test("history_visibility", [room_id_1, room_id_2, room_id_3])
+        _order_test(
+            "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True
+        )
+
+        _order_test("state_events", [room_id_3, room_id_2, room_id_1])
+        _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
+
+    def test_search_term(self):
+        """Test that searching for a room works correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        room_name_1 = "something"
+        room_name_2 = "else"
+
+        # Set the name for each room
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+        )
+
+        def _search_test(
+            expected_room_id: Optional[str],
+            search_term: str,
+            expected_http_code: int = 200,
+        ):
+            """Search for a room and check that the returned room's id is a match
+
+            Args:
+                expected_room_id: The room_id expected to be returned by the API. Set
+                    to None to expect zero results for the search
+                search_term: The term to search for room names with
+                expected_http_code: The expected http code for the request
+            """
+            url = "/_synapse/admin/v1/rooms?search_term=%s" % (search_term,)
+            request, channel = self.make_request(
+                "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+            )
+            self.render(request)
+            self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)
+
+            if expected_http_code != 200:
+                return
+
+            # Check that rooms were returned
+            self.assertTrue("rooms" in channel.json_body)
+            rooms = channel.json_body["rooms"]
+
+            # Check that the expected number of rooms were returned
+            expected_room_count = 1 if expected_room_id else 0
+            self.assertEqual(len(rooms), expected_room_count)
+            self.assertEqual(channel.json_body["total_rooms"], expected_room_count)
+
+            # Check that the offset is correct
+            # We're not paginating, so should be 0
+            self.assertEqual(channel.json_body["offset"], 0)
+
+            # Check that there is no `prev_batch`
+            self.assertNotIn("prev_batch", channel.json_body)
+
+            # Check that there is no `next_batch`
+            self.assertNotIn("next_batch", channel.json_body)
+
+            if expected_room_id:
+                # Check that the first returned room id is correct
+                r = rooms[0]
+                self.assertEqual(expected_room_id, r["room_id"])
+
+        # Perform search tests
+        _search_test(room_id_1, "something")
+        _search_test(room_id_1, "thing")
+
+        _search_test(room_id_2, "else")
+        _search_test(room_id_2, "se")
+
+        _search_test(None, "foo")
+        _search_test(None, "bar")
+        _search_test(None, "", expected_http_code=400)
+
+    def test_single_room(self):
+        """Test that a single room can be requested correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        room_name_1 = "something"
+        room_name_2 = "else"
+
+        # Set the name for each room
+        self.helper.send_state(
+            room_id_1, "m.room.name", {"name": room_name_1}, tok=self.admin_user_tok,
+        )
+        self.helper.send_state(
+            room_id_2, "m.room.name", {"name": room_name_2}, tok=self.admin_user_tok,
+        )
+
+        url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertIn("room_id", channel.json_body)
+        self.assertIn("name", channel.json_body)
+        self.assertIn("canonical_alias", channel.json_body)
+        self.assertIn("joined_members", channel.json_body)
+        self.assertIn("joined_local_members", channel.json_body)
+        self.assertIn("version", channel.json_body)
+        self.assertIn("creator", channel.json_body)
+        self.assertIn("encryption", channel.json_body)
+        self.assertIn("federatable", channel.json_body)
+        self.assertIn("public", channel.json_body)
+        self.assertIn("join_rules", channel.json_body)
+        self.assertIn("guest_access", channel.json_body)
+        self.assertIn("history_visibility", channel.json_body)
+        self.assertIn("state_events", channel.json_body)
+
+        self.assertEqual(room_id_1, channel.json_body["room_id"])
+
+    def test_room_members(self):
+        """Test that room members can be requested correctly"""
+        # Create two test rooms
+        room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+        room_id_2 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+        # Have another user join the room
+        user_1 = self.register_user("foo", "pass")
+        user_tok_1 = self.login("foo", "pass")
+        self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+        # Have another user join the room
+        user_2 = self.register_user("bar", "pass")
+        user_tok_2 = self.login("bar", "pass")
+        self.helper.join(room_id_1, user_2, tok=user_tok_2)
+        self.helper.join(room_id_2, user_2, tok=user_tok_2)
+
+        # Have another user join the room
+        user_3 = self.register_user("foobar", "pass")
+        user_tok_3 = self.login("foobar", "pass")
+        self.helper.join(room_id_2, user_3, tok=user_tok_3)
+
+        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_1,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertCountEqual(
+            ["@admin:test", "@foo:test", "@bar:test"], channel.json_body["members"]
+        )
+        self.assertEqual(channel.json_body["total"], 3)
+
+        url = "/_synapse/admin/v1/rooms/%s/members" % (room_id_2,)
+        request, channel = self.make_request(
+            "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        self.assertCountEqual(
+            ["@admin:test", "@bar:test", "@foobar:test"], channel.json_body["members"]
+        )
+        self.assertEqual(channel.json_body["total"], 3)
+
+
+class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.creator = self.register_user("creator", "test")
+        self.creator_tok = self.login("creator", "test")
+
+        self.second_user_id = self.register_user("second", "test")
+        self.second_tok = self.login("second", "test")
+
+        self.public_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=True
+        )
+        self.url = "/_synapse/admin/v1/join/{}".format(self.public_room_id)
+
+    def test_requester_is_no_admin(self):
+        """
+        If the user is not a server admin, an error 403 is returned.
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.second_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_invalid_parameter(self):
+        """
+        If a parameter is missing, return an error
+        """
+        body = json.dumps({"unknown_parameter": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+    def test_local_user_does_not_exist(self):
+        """
+        Tests that a lookup for a user that does not exist returns a 404
+        """
+        body = json.dumps({"user_id": "@unknown:test"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    def test_remote_user(self):
+        """
+        Check that only local user can join rooms.
+        """
+        body = json.dumps({"user_id": "@not:exist.bla"})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "This endpoint can only be used with local users",
+            channel.json_body["error"],
+        )
+
+    def test_room_does_not_exist(self):
+        """
+        Check that unknown rooms/server return error 404.
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+        url = "/_synapse/admin/v1/join/!unknown:test"
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("No known servers", channel.json_body["error"])
+
+    def test_room_is_not_valid(self):
+        """
+        Check that invalid room names, return an error 400.
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+        url = "/_synapse/admin/v1/join/invalidroom"
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(
+            "invalidroom was not legal room ID or room alias",
+            channel.json_body["error"],
+        )
+
+    def test_join_public_room(self):
+        """
+        Test joining a local user to a public room with "JoinRules.PUBLIC"
+        """
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            self.url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.public_room_id, channel.json_body["room_id"])
+
+        # Validate if user is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
+
+    def test_join_private_room_if_not_member(self):
+        """
+        Test joining a local user to a private room with "JoinRules.INVITE"
+        when server admin is not member of this room.
+        """
+        private_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False
+        )
+        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_join_private_room_if_member(self):
+        """
+        Test joining a local user to a private room with "JoinRules.INVITE",
+        when server admin is member of this room.
+        """
+        private_room_id = self.helper.create_room_as(
+            self.creator, tok=self.creator_tok, is_public=False
+        )
+        self.helper.invite(
+            room=private_room_id,
+            src=self.creator,
+            targ=self.admin_user,
+            tok=self.creator_tok,
+        )
+        self.helper.join(
+            room=private_room_id, user=self.admin_user, tok=self.admin_user_tok
+        )
+
+        # Validate if server admin is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+        # Join user to room.
+
+        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+        # Validate if user is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+
+    def test_join_private_room_if_owner(self):
+        """
+        Test joining a local user to a private room with "JoinRules.INVITE",
+        when server admin is owner of this room.
+        """
+        private_room_id = self.helper.create_room_as(
+            self.admin_user, tok=self.admin_user_tok, is_public=False
+        )
+        url = "/_synapse/admin/v1/join/{}".format(private_room_id)
+        body = json.dumps({"user_id": self.second_user_id})
+
+        request, channel = self.make_request(
+            "POST",
+            url,
+            content=body.encode(encoding="utf_8"),
+            access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["room_id"])
+
+        # Validate if user is a member of the room
+
+        request, channel = self.make_request(
+            "GET", "/_matrix/client/r0/joined_rooms", access_token=self.second_tok,
+        )
+        self.render(request)
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index cca5f548e6..160c630235 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -27,6 +27,7 @@ from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import sync
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
@@ -335,7 +336,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         store = self.hs.get_datastore()
 
         # Set monthly active users to the limit
-        store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+        store.get_monthly_active_count = Mock(
+            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
+        )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
         self.get_failure(
@@ -588,7 +591,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            return_value=self.hs.config.max_mau_value
+            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -628,7 +631,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         # Set monthly active users to the limit
         self.store.get_monthly_active_count = Mock(
-            return_value=self.hs.config.max_mau_value
+            side_effect=lambda: make_awaitable(self.hs.config.max_mau_value)
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -857,6 +860,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("@user:test", channel.json_body["name"])
         self.assertEqual(True, channel.json_body["deactivated"])
 
+    def test_reactivate_user(self):
+        """
+        Test reactivating another user.
+        """
+
+        # Deactivate the user.
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Attempt to reactivate the user (without a password).
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+        )
+        self.render(request)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Reactivate the user.
+        request, channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
+                encoding="utf_8"
+            ),
+        )
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_other_user, access_token=self.admin_user_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(False, channel.json_body["deactivated"])
+
     def test_set_user_as_admin(self):
         """
         Test setting the admin flag on a user.
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..7d3773ff78 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         }
 
         self.hs = self.setup_test_homeserver(config=config)
+
         return self.hs
 
     def prepare(self, reactor, clock, homeserver):
         self.user_id = self.register_user("user", "password")
         self.token = self.login("user", "password")
 
-    def test_retention_state_event(self):
-        """Tests that the server configuration can limit the values a user can set to the
-        room's retention policy.
+        self.store = self.hs.get_datastore()
+        self.serializer = self.hs.get_event_client_serializer()
+        self.clock = self.hs.get_clock()
+
+    def test_retention_event_purged_with_state_event(self):
+        """Tests that expired events are correctly purged when the room's retention policy
+        is defined by a state event.
         """
         room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
+        # Set the room's retention period to 2 days.
+        lifetime = one_day_ms * 2
         self.helper.send_state(
             room_id=room_id,
             event_type=EventTypes.Retention,
-            body={"max_lifetime": one_day_ms * 4},
+            body={"max_lifetime": lifetime},
             tok=self.token,
-            expect_code=400,
         )
 
+        self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+    def test_retention_event_purged_with_state_event_outside_allowed(self):
+        """Tests that the server configuration can override the policy for a room when
+        running the purge jobs.
+        """
+        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+        # Set a max_lifetime higher than the maximum allowed value.
         self.helper.send_state(
             room_id=room_id,
             event_type=EventTypes.Retention,
-            body={"max_lifetime": one_hour_ms},
+            body={"max_lifetime": one_day_ms * 4},
             tok=self.token,
-            expect_code=400,
         )
 
-    def test_retention_event_purged_with_state_event(self):
-        """Tests that expired events are correctly purged when the room's retention policy
-        is defined by a state event.
-        """
-        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+        # Check that the event is purged after waiting for the maximum allowed duration
+        # instead of the one specified in the room's policy.
+        self._test_retention_event_purged(room_id, one_day_ms * 1.5)
 
-        # Set the room's retention period to 2 days.
-        lifetime = one_day_ms * 2
+        # Set a max_lifetime lower than the minimum allowed value.
         self.helper.send_state(
             room_id=room_id,
             event_type=EventTypes.Retention,
-            body={"max_lifetime": lifetime},
+            body={"max_lifetime": one_hour_ms},
             tok=self.token,
         )
 
-        self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+        # Check that the event is purged after waiting for the minimum allowed duration
+        # instead of the one specified in the room's policy.
+        self._test_retention_event_purged(room_id, one_day_ms * 0.5)
 
     def test_retention_event_purged_without_state_event(self):
         """Tests that expired events are correctly purged when the room's retention policy
@@ -126,7 +139,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
 
         events.append(self.get_success(store.get_event(valid_event_id)))
 
-        # Advance the time by anothe 2 days. After this, the first event should be
+        # Advance the time by another 2 days. After this, the first event should be
         # outdated but not the second one.
         self.reactor.advance(one_day_ms * 2 / 1000)
 
@@ -140,11 +153,33 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # That event should be the second, not outdated event.
         self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
 
-    def _test_retention_event_purged(self, room_id, increment):
+    def _test_retention_event_purged(self, room_id: str, increment: float):
+        """Run the following test scenario to test the message retention policy support:
+
+        1. Send event 1
+        2. Increment time by `increment`
+        3. Send event 2
+        4. Increment time by `increment`
+        5. Check that event 1 has been purged
+        6. Check that event 2 has not been purged
+        7. Check that state events that were sent before event 1 aren't purged.
+        The main reason for sending a second event is because currently Synapse won't
+        purge the latest message in a room because it would otherwise result in a lack of
+        forward extremities for this room. It's also a good thing to ensure the purge jobs
+        aren't too greedy and purge messages they shouldn't.
+
+        Args:
+            room_id: The ID of the room to test retention in.
+            increment: The number of milliseconds to advance the clock each time. Must be
+                defined so that events in the room aren't purged if they are `increment`
+                old but are purged if they are `increment * 2` old.
+        """
         # Get the create event to, later, check that we can still access it.
         message_handler = self.hs.get_message_handler()
         create_event = self.get_success(
-            message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+            message_handler.get_room_data(
+                self.user_id, room_id, EventTypes.Create, state_key=""
+            )
         )
 
         # Send a first event to the room. This is the event we'll want to be purged at the
@@ -154,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         expired_event_id = resp.get("event_id")
 
         # Check that we can retrieve the event.
-        expired_event = self.get_event(room_id, expired_event_id)
+        expired_event = self.get_event(expired_event_id)
         self.assertEqual(
             expired_event.get("content", {}).get("body"), "1", expired_event
         )
@@ -172,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         # one should still be kept.
         self.reactor.advance(increment / 1000)
 
-        # Check that the event has been purged from the database.
-        self.get_event(room_id, expired_event_id, expected_code=404)
+        # Check that the first event has been purged from the database, i.e. that we
+        # can't retrieve it anymore, because it has expired.
+        self.get_event(expired_event_id, expect_none=True)
 
-        # Check that the event that hasn't been purged can still be retrieved.
-        valid_event = self.get_event(room_id, valid_event_id)
+        # Check that the event that hasn't expired can still be retrieved.
+        valid_event = self.get_event(valid_event_id)
         self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
 
         # Check that we can still access state events that were sent before the event that
         # has been purged.
         self.get_event(room_id, create_event.event_id)
 
-    def get_event(self, room_id, event_id, expected_code=200):
-        url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+    def get_event(self, event_id, expect_none=False):
+        event = self.get_success(self.store.get_event(event_id, allow_none=True))
 
-        request, channel = self.make_request("GET", url, access_token=self.token)
-        self.render(request)
+        if expect_none:
+            self.assertIsNone(event)
+            return {}
 
-        self.assertEqual(channel.code, expected_code, channel.result)
+        self.assertIsNotNone(event)
 
-        return channel.json_body
+        time_now = self.clock.time_msec()
+        serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+        return serialized
 
 
 class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
new file mode 100644
index 0000000000..dfe4bf7762
--- /dev/null
+++ b/tests/rest/client/test_shadow_banned.py
@@ -0,0 +1,312 @@
+#  Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+from mock import Mock, patch
+
+import synapse.rest.admin
+from synapse.api.constants import EventTypes
+from synapse.rest.client.v1 import directory, login, profile, room
+from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+
+from tests import unittest
+
+
+class _ShadowBannedBase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, homeserver):
+        # Create two users, one of which is shadow-banned.
+        self.banned_user_id = self.register_user("banned", "test")
+        self.banned_access_token = self.login("banned", "test")
+
+        self.store = self.hs.get_datastore()
+
+        self.get_success(
+            self.store.db_pool.simple_update(
+                table="users",
+                keyvalues={"name": self.banned_user_id},
+                updatevalues={"shadow_banned": True},
+                desc="shadow_ban",
+            )
+        )
+
+        self.other_user_id = self.register_user("otheruser", "pass")
+        self.other_access_token = self.login("otheruser", "pass")
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class RoomTestCase(_ShadowBannedBase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        directory.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        room_upgrade_rest_servlet.register_servlets,
+    ]
+
+    def test_invite(self):
+        """Invites from shadow-banned users don't actually get sent."""
+
+        # The create works fine.
+        room_id = self.helper.create_room_as(
+            self.banned_user_id, tok=self.banned_access_token
+        )
+
+        # Inviting the user completes successfully.
+        self.helper.invite(
+            room=room_id,
+            src=self.banned_user_id,
+            tok=self.banned_access_token,
+            targ=self.other_user_id,
+        )
+
+        # But the user wasn't actually invited.
+        invited_rooms = self.get_success(
+            self.store.get_invited_rooms_for_local_user(self.other_user_id)
+        )
+        self.assertEqual(invited_rooms, [])
+
+    def test_invite_3pid(self):
+        """Ensure that a 3PID invite does not attempt to contact the identity server."""
+        identity_handler = self.hs.get_handlers().identity_handler
+        identity_handler.lookup_3pid = Mock(
+            side_effect=AssertionError("This should not get called")
+        )
+
+        # The create works fine.
+        room_id = self.helper.create_room_as(
+            self.banned_user_id, tok=self.banned_access_token
+        )
+
+        # Inviting the user completes successfully.
+        request, channel = self.make_request(
+            "POST",
+            "/rooms/%s/invite" % (room_id,),
+            {"id_server": "test", "medium": "email", "address": "test@test.test"},
+            access_token=self.banned_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+
+        # This should have raised an error earlier, but double check this wasn't called.
+        identity_handler.lookup_3pid.assert_not_called()
+
+    def test_create_room(self):
+        """Invitations during a room creation should be discarded, but the room still gets created."""
+        # The room creation is successful.
+        request, channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/createRoom",
+            {"visibility": "public", "invite": [self.other_user_id]},
+            access_token=self.banned_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+        room_id = channel.json_body["room_id"]
+
+        # But the user wasn't actually invited.
+        invited_rooms = self.get_success(
+            self.store.get_invited_rooms_for_local_user(self.other_user_id)
+        )
+        self.assertEqual(invited_rooms, [])
+
+        # Since a real room was created, the other user should be able to join it.
+        self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+        # Both users should be in the room.
+        users = self.get_success(self.store.get_users_in_room(room_id))
+        self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+    def test_message(self):
+        """Messages from shadow-banned users don't actually get sent."""
+
+        room_id = self.helper.create_room_as(
+            self.other_user_id, tok=self.other_access_token
+        )
+
+        # The user should be in the room.
+        self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+        # Sending a message should complete successfully.
+        result = self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={"msgtype": "m.text", "body": "with right label"},
+            tok=self.banned_access_token,
+        )
+        self.assertIn("event_id", result)
+        event_id = result["event_id"]
+
+        latest_events = self.get_success(
+            self.store.get_latest_event_ids_in_room(room_id)
+        )
+        self.assertNotIn(event_id, latest_events)
+
+    def test_upgrade(self):
+        """A room upgrade should fail, but look like it succeeded."""
+
+        # The create works fine.
+        room_id = self.helper.create_room_as(
+            self.banned_user_id, tok=self.banned_access_token
+        )
+
+        request, channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+            {"new_version": "6"},
+            access_token=self.banned_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+        # A new room_id should be returned.
+        self.assertIn("replacement_room", channel.json_body)
+
+        new_room_id = channel.json_body["replacement_room"]
+
+        # It doesn't really matter what API we use here, we just want to assert
+        # that the room doesn't exist.
+        summary = self.get_success(self.store.get_room_summary(new_room_id))
+        # The summary should be empty since the room doesn't exist.
+        self.assertEqual(summary, {})
+
+    def test_typing(self):
+        """Typing notifications should not be propagated into the room."""
+        # The create works fine.
+        room_id = self.helper.create_room_as(
+            self.banned_user_id, tok=self.banned_access_token
+        )
+
+        request, channel = self.make_request(
+            "PUT",
+            "/rooms/%s/typing/%s" % (room_id, self.banned_user_id),
+            {"typing": True, "timeout": 30000},
+            access_token=self.banned_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+
+        # There should be no typing events.
+        event_source = self.hs.get_event_sources().sources["typing"]
+        self.assertEquals(event_source.get_current_key(), 0)
+
+        # The other user can join and send typing events.
+        self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+        request, channel = self.make_request(
+            "PUT",
+            "/rooms/%s/typing/%s" % (room_id, self.other_user_id),
+            {"typing": True, "timeout": 30000},
+            access_token=self.other_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code)
+
+        # These appear in the room.
+        self.assertEquals(event_source.get_current_key(), 1)
+        events = self.get_success(
+            event_source.get_new_events(from_key=0, room_ids=[room_id])
+        )
+        self.assertEquals(
+            events[0],
+            [
+                {
+                    "type": "m.typing",
+                    "room_id": room_id,
+                    "content": {"user_ids": [self.other_user_id]},
+                }
+            ],
+        )
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ProfileTestCase(_ShadowBannedBase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    def test_displayname(self):
+        """Profile changes should succeed, but don't end up in a room."""
+        original_display_name = "banned"
+        new_display_name = "new name"
+
+        # Join a room.
+        room_id = self.helper.create_room_as(
+            self.banned_user_id, tok=self.banned_access_token
+        )
+
+        # The update should succeed.
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
+            {"displayname": new_display_name},
+            access_token=self.banned_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEqual(channel.json_body, {})
+
+        # The user's display name should be updated.
+        request, channel = self.make_request(
+            "GET", "/profile/%s/displayname" % (self.banned_user_id,)
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+        self.assertEqual(channel.json_body["displayname"], new_display_name)
+
+        # But the display name in the room should not be.
+        message_handler = self.hs.get_message_handler()
+        event = self.get_success(
+            message_handler.get_room_data(
+                self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+            )
+        )
+        self.assertEqual(
+            event.content, {"membership": "join", "displayname": original_display_name}
+        )
+
+    def test_room_displayname(self):
+        """Changes to state events for a room should be processed, but not end up in the room."""
+        original_display_name = "banned"
+        new_display_name = "new name"
+
+        # Join a room.
+        room_id = self.helper.create_room_as(
+            self.banned_user_id, tok=self.banned_access_token
+        )
+
+        # The update should succeed.
+        request, channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
+            % (room_id, self.banned_user_id),
+            {"membership": "join", "displayname": new_display_name},
+            access_token=self.banned_access_token,
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertIn("event_id", channel.json_body)
+
+        # The display name in the room should not be changed.
+        message_handler = self.hs.get_message_handler()
+        event = self.get_success(
+            message_handler.get_room_data(
+                self.banned_user_id, room_id, "m.room.member", self.banned_user_id,
+            )
+        )
+        self.assertEqual(
+            event.content, {"membership": "join", "displayname": original_display_name}
+        )
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
index 7167fc56b6..8c24add530 100644
--- a/tests/rest/client/third_party_rules.py
+++ b/tests/rest/client/third_party_rules.py
@@ -19,7 +19,7 @@ from synapse.rest.client.v1 import login, room
 from tests import unittest
 
 
-class ThirdPartyRulesTestModule(object):
+class ThirdPartyRulesTestModule:
     def __init__(self, config):
         pass
 
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 9033f09fd2..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
                 "password": "monkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # than 1min.
         self.assertTrue(retry_after_ms < 6000)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         params = {
             "type": "m.login.password",
             "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
             "password": "monkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "monkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "monkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
                 "identifier": {"type": "m.id.user", "user": "kermit"},
                 "password": "notamonkey",
             }
-            request_data = json.dumps(params)
-            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            request, channel = self.make_request(b"POST", LOGIN_URL, params)
             self.render(request)
 
             if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # than 1min.
         self.assertTrue(retry_after_ms < 6000)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         params = {
             "type": "m.login.password",
             "identifier": {"type": "m.id.user", "user": "kermit"},
             "password": "notamonkey",
         }
-        request_data = json.dumps(params)
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
 
@@ -398,7 +392,7 @@ class CASTestCase(unittest.HomeserverTestCase):
                 </cas:serviceResponse>
             """
                 % cas_user_id
-            )
+            ).encode("utf-8")
 
         mocked_http_client = Mock(spec=["get_raw"])
         mocked_http_client.get_raw.side_effect = get_raw
@@ -514,19 +508,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
     ]
 
     jwt_secret = "secret"
+    jwt_algorithm = "HS256"
 
     def make_homeserver(self, reactor, clock):
         self.hs = self.setup_test_homeserver()
         self.hs.config.jwt_enabled = True
         self.hs.config.jwt_secret = self.jwt_secret
-        self.hs.config.jwt_algorithm = "HS256"
+        self.hs.config.jwt_algorithm = self.jwt_algorithm
         return self.hs
 
     def jwt_encode(self, token, secret=jwt_secret):
-        return jwt.encode(token, secret, "HS256").decode("ascii")
+        return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
 
     def jwt_login(self, *args):
-        params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+        params = json.dumps(
+            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+        )
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
         return channel
@@ -544,35 +541,126 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_invalid_signature(self):
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )
 
     def test_login_jwt_expired(self):
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "JWT expired")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Signature has expired"
+        )
 
     def test_login_jwt_not_before(self):
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: The token is not yet valid (nbf)",
+        )
 
     def test_login_no_sub(self):
         channel = self.jwt_login({"username": "root"})
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "issuer": "test-issuer",
+            }
+        }
+    )
+    def test_login_iss(self):
+        """Test validating the issuer claim."""
+        # A valid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid issuer.
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid issuer"
+        )
+
+        # Not providing an issuer.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "iss" claim',
+        )
+
+    def test_login_iss_no_config(self):
+        """Test providing an issuer claim without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+    @override_config(
+        {
+            "jwt_config": {
+                "jwt_enabled": True,
+                "secret": jwt_secret,
+                "algorithm": jwt_algorithm,
+                "audiences": ["test-audience"],
+            }
+        }
+    )
+    def test_login_aud(self):
+        """Test validating the audience claim."""
+        # A valid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+        # An invalid audience.
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
+        # Not providing an audience.
+        channel = self.jwt_login({"sub": "kermit"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            'JWT validation failed: Token is missing the "aud" claim',
+        )
+
+    def test_login_aud_no_config(self):
+        """Test providing an audience without requiring it in the configuration."""
+        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"], "JWT validation failed: Invalid audience"
+        )
+
     def test_login_no_token(self):
-        params = json.dumps({"type": "m.login.jwt"})
+        params = json.dumps({"type": "org.matrix.login.jwt"})
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
 
 
@@ -640,7 +728,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
         return jwt.encode(token, secret, "RS256").decode("ascii")
 
     def jwt_login(self, *args):
-        params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+        params = json.dumps(
+            {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+        )
         request, channel = self.make_request(b"POST", LOGIN_URL, params)
         self.render(request)
         return channel
@@ -652,6 +742,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_invalid_signature(self):
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
-        self.assertEqual(channel.json_body["error"], "Invalid JWT")
+        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+        self.assertEqual(
+            channel.json_body["error"],
+            "JWT validation failed: Signature verification failed",
+        )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0fdff79aa7..3c66255dac 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
     def test_put_presence_disabled(self):
         """
-        PUT to the status endpoint with use_presence disbled will NOT call
+        PUT to the status endpoint with use_presence disabled will NOT call
         set_state on the presence handler.
         """
         self.hs.config.use_presence = False
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
             profile_handler=self.mock_handler,
         )
 
-        def _get_user_by_req(request=None, allow_guest=False):
-            return defer.succeed(synapse.types.create_requester(myid))
+        async def _get_user_by_req(request=None, allow_guest=False):
+            return synapse.types.create_requester(myid)
 
         hs.get_auth().get_user_by_req = _get_user_by_req
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4886bbb401..0a567b032f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,18 +19,16 @@
 """Tests REST events for /rooms paths."""
 
 import json
+from urllib import parse as urlparse
 
 from mock import Mock
-from six.moves.urllib import parse as urlparse
-
-from twisted.internet import defer
 
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
 
         self.hs.get_federation_handler = Mock(return_value=Mock())
 
-        def _insert_client_ip(*args, **kwargs):
-            return defer.succeed(None)
+        async def _insert_client_ip(*args, **kwargs):
+            return None
 
         self.hs.get_datastore().insert_client_ip = _insert_client_ip
 
@@ -677,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomJoinRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_join_local_ratelimit(self):
+        """Tests that local joins are actually rate-limited."""
+        for i in range(3):
+            self.helper.create_room_as(self.user_id)
+
+        self.helper.create_room_as(self.user_id, expect_code=429)
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_join_local_ratelimit_profile_change(self):
+        """Tests that sending a profile update into all of the user's joined rooms isn't
+        rate-limited by the rate-limiter on joins."""
+
+        # Create and join as many rooms as the rate-limiting config allows in a second.
+        room_ids = [
+            self.helper.create_room_as(self.user_id),
+            self.helper.create_room_as(self.user_id),
+            self.helper.create_room_as(self.user_id),
+        ]
+        # Let some time for the rate-limiter to forget about our multi-join.
+        self.reactor.advance(2)
+        # Add one to make sure we're joined to more rooms than the config allows us to
+        # join in a second.
+        room_ids.append(self.helper.create_room_as(self.user_id))
+
+        # Create a profile for the user, since it hasn't been done on registration.
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.create_profile(UserID.from_string(self.user_id).localpart)
+        )
+
+        # Update the display name for the user.
+        path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+        request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+        self.render(request)
+        self.assertEquals(channel.code, 200, channel.json_body)
+
+        # Check that all the rooms have been sent a profile update into.
+        for room_id in room_ids:
+            path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+                room_id,
+                self.user_id,
+            )
+
+            request, channel = self.make_request("GET", path)
+            self.render(request)
+            self.assertEquals(channel.code, 200)
+
+            self.assertIn("displayname", channel.json_body)
+            self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_join_local_ratelimit_idempotent(self):
+        """Tests that the room join endpoints remain idempotent despite rate-limiting
+        on room joins."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        # Let's test both paths to be sure.
+        paths_to_test = [
+            "/_matrix/client/r0/rooms/%s/join",
+            "/_matrix/client/r0/join/%s",
+        ]
+
+        for path in paths_to_test:
+            # Make sure we send more requests than the rate-limiting config would allow
+            # if all of these requests ended up joining the user to a room.
+            for i in range(4):
+                request, channel = self.make_request("POST", path % room_id, {})
+                self.render(request)
+                self.assertEquals(channel.code, 200)
+
+
 class RoomMessagesTestCase(RoomBase):
     """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
 
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def get_user_by_access_token(token=None, allow_guest=False):
+        async def get_user_by_access_token(token=None, allow_guest=False):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_auth().get_user_by_access_token = get_user_by_access_token
 
-        def _insert_client_ip(*args, **kwargs):
-            return defer.succeed(None)
+        async def _insert_client_ip(*args, **kwargs):
+            return None
 
         hs.get_datastore().insert_client_ip = _insert_client_ip
 
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..afaf9f7b85 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -30,7 +30,7 @@ from tests.server import make_request, render
 
 
 @attr.s
-class RestHelper(object):
+class RestHelper:
     """Contains extra helper functions to quickly and clearly perform a given
     REST action, which isn't the focus of the test.
     """
@@ -39,7 +39,9 @@ class RestHelper(object):
     resource = attr.ib()
     auth_user_id = attr.ib()
 
-    def create_room_as(self, room_creator=None, is_public=True, tok=None):
+    def create_room_as(
+        self, room_creator=None, is_public=True, tok=None, expect_code=200,
+    ):
         temp_id = self.auth_user_id
         self.auth_user_id = room_creator
         path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
         )
         render(request, self.resource, self.hs.get_reactor())
 
-        assert channel.result["code"] == b"200", channel.result
+        assert channel.result["code"] == b"%d" % expect_code, channel.result
         self.auth_user_id = temp_id
-        return channel.json_body["room_id"]
+
+        if expect_code == 200:
+            return channel.json_body["room_id"]
 
     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(
@@ -88,7 +92,28 @@ class RestHelper(object):
             expect_code=expect_code,
         )
 
-    def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+    def change_membership(
+        self,
+        room: str,
+        src: str,
+        targ: str,
+        membership: str,
+        extra_data: dict = {},
+        tok: Optional[str] = None,
+        expect_code: int = 200,
+    ) -> None:
+        """
+        Send a membership state event into a room.
+
+        Args:
+            room: The ID of the room to send to
+            src: The mxid of the event sender
+            targ: The mxid of the event's target. The state key
+            membership: The type of membership event
+            extra_data: Extra information to include in the content of the event
+            tok: The user access token to use
+            expect_code: The expected HTTP response code
+        """
         temp_id = self.auth_user_id
         self.auth_user_id = src
 
@@ -97,6 +122,7 @@ class RestHelper(object):
             path = path + "?access_token=%s" % tok
 
         data = {"membership": membership}
+        data.update(extra_data)
 
         request, channel = make_request(
             self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 3ab611f618..152a5182fa 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    def test_basic_password_reset_canonicalise_email(self):
+        """Test basic password reset flow
+        Request password reset with different spelling
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email_profile = "test@example.com"
+        email_passwort_reset = "TEST@EXAMPLE.COM"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email_profile,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        session_id = self._request_token(email_passwort_reset, client_secret)
+
+        self.assertEquals(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link)
+
+        self._reset_password(new_password, session_id, client_secret)
+
+        # Assert we can log in with the new password
+        self.login("kermit", new_password)
+
+        # Assert we can't log in with the old password
+        self.attempt_wrong_password_login("kermit", old_password)
+
     def test_cant_reset_password_without_clicking_link(self):
         """Test that we do actually need to click the link in the email
         """
@@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.email = "test@example.com"
         self.url_3pid = b"account/3pid"
 
-    def test_add_email(self):
-        """Test adding an email to profile
-        """
-        client_secret = "foobar"
-        session_id = self._request_token(self.email, client_secret)
+    def test_add_valid_email(self):
+        self.get_success(self._add_email(self.email, self.email))
 
-        self.assertEquals(len(self.email_attempts), 1)
-        link = self._get_link_from_email()
+    def test_add_valid_email_second_time(self):
+        self.get_success(self._add_email(self.email, self.email))
+        self.get_success(
+            self._request_token_invalid_email(
+                self.email,
+                expected_errcode=Codes.THREEPID_IN_USE,
+                expected_error="Email is already in use",
+            )
+        )
 
-        self._validate_token(link)
+    def test_add_valid_email_second_time_canonicalise(self):
+        self.get_success(self._add_email(self.email, self.email))
+        self.get_success(
+            self._request_token_invalid_email(
+                "TEST@EXAMPLE.COM",
+                expected_errcode=Codes.THREEPID_IN_USE,
+                expected_error="Email is already in use",
+            )
+        )
 
-        request, channel = self.make_request(
-            "POST",
-            b"/_matrix/client/unstable/account/3pid/add",
-            {
-                "client_secret": client_secret,
-                "sid": session_id,
-                "auth": {
-                    "type": "m.login.password",
-                    "user": self.user_id,
-                    "password": "test",
-                },
-            },
-            access_token=self.user_id_tok,
+    def test_add_email_no_at(self):
+        self.get_success(
+            self._request_token_invalid_email(
+                "address-without-at.bar",
+                expected_errcode=Codes.UNKNOWN,
+                expected_error="Unable to parse email address",
+            )
         )
 
-        self.render(request)
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+    def test_add_email_two_at(self):
+        self.get_success(
+            self._request_token_invalid_email(
+                "foo@foo@test.bar",
+                expected_errcode=Codes.UNKNOWN,
+                expected_error="Unable to parse email address",
+            )
+        )
 
-        # Get user
-        request, channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+    def test_add_email_bad_format(self):
+        self.get_success(
+            self._request_token_invalid_email(
+                "user@bad.example.net@good.example.com",
+                expected_errcode=Codes.UNKNOWN,
+                expected_error="Unable to parse email address",
+            )
         )
-        self.render(request)
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+    def test_add_email_domain_to_lower(self):
+        self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
+
+    def test_add_email_domain_with_umlaut(self):
+        self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
+
+    def test_add_email_address_casefold(self):
+        self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
+
+    def test_address_trim(self):
+        self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
     def test_add_email_if_disabled(self):
         """Test adding email to profile when doing so is disallowed
@@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         return channel.json_body["sid"]
 
+    def _request_token_invalid_email(
+        self, email, expected_errcode, expected_error, client_secret="foobar",
+    ):
+        request, channel = self.make_request(
+            "POST",
+            b"account/3pid/email/requestToken",
+            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+        )
+        self.render(request)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(expected_errcode, channel.json_body["errcode"])
+        self.assertEqual(expected_error, channel.json_body["error"])
+
     def _validate_token(self, link):
         # Remove the host
         path = link.replace("https://example.com", "")
@@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         assert match, "Could not find link in email"
 
         return match.group(0)
+
+    def _add_email(self, request_email, expected_email):
+        """Test adding an email to profile
+        """
+        client_secret = "foobar"
+        session_id = self._request_token(request_email, client_secret)
+
+        self.assertEquals(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link)
+
+        request, channel = self.make_request(
+            "POST",
+            b"/_matrix/client/unstable/account/3pid/add",
+            {
+                "client_secret": client_secret,
+                "sid": session_id,
+                "auth": {
+                    "type": "m.login.password",
+                    "user": self.user_id,
+                    "password": "test",
+                },
+            },
+            access_token=self.user_id_tok,
+        )
+
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_3pid, access_token=self.user_id_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index e0e9e94fbf..de00350580 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from synapse.api.errors import Codes
 from synapse.rest.client.v2_alpha import filter
 
@@ -73,8 +75,10 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
 
     def test_get_filter(self):
-        filter_id = self.filtering.add_user_filter(
-            user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+        filter_id = defer.ensureDeferred(
+            self.filtering.add_user_filter(
+                user_localpart="apple", user_filter=self.EXAMPLE_FILTER
+            )
         )
         self.reactor.advance(1)
         filter_id = filter_id.result
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..2fc3a60fc5 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
+    @override_config({"enable_registration": False})
     def test_POST_disabled_registration(self):
-        self.hs.config.enable_registration = False
         request_data = json.dumps({"username": "kermit", "password": "monkey"})
         self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
 
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             else:
                 self.assertEquals(channel.result["code"], b"200", channel.result)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             else:
                 self.assertEquals(channel.result["code"], b"200", channel.result)
 
-        self.reactor.advance(retry_after_ms / 1000.0)
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
         self.render(request)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index c7e5859970..99c9f4e928 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -15,8 +15,7 @@
 
 import itertools
 import json
-
-import six
+import urllib
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
@@ -100,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(400, channel.code, channel.json_body)
 
     def test_basic_paginate_relations(self):
-        """Tests that calling pagination API corectly the latest relations.
+        """Tests that calling pagination API correctly the latest relations.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
         self.assertEquals(200, channel.code, channel.json_body)
@@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         # Make sure next_batch has something in it that looks like it could be a
         # valid token.
         self.assertIsInstance(
-            channel.json_body.get("next_batch"), six.string_types, channel.json_body
+            channel.json_body.get("next_batch"), str, channel.json_body
         )
 
     def test_repeated_paginate_relations(self):
@@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         prev_token = None
         found_event_ids = []
-        encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
+        encoded_key = urllib.parse.quote_plus("👍".encode("utf-8"))
         for _ in range(20):
             from_token = ""
             if prev_token:
@@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
         query = ""
         if key:
-            query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+            query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
 
         original_id = parent_id if parent_id else self.parent_id
 
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
new file mode 100644
index 0000000000..5ae72fd008
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import shared_rooms
+
+from tests import unittest
+
+
+class UserSharedRoomsTest(unittest.HomeserverTestCase):
+    """
+    Tests the UserSharedRoomsServlet.
+    """
+
+    servlets = [
+        login.register_servlets,
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        shared_rooms.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        config["update_user_directory"] = True
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.handler = hs.get_user_directory_handler()
+
+    def _get_shared_rooms(self, token, other_user):
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+            % other_user,
+            access_token=token,
+        )
+        self.render(request)
+        return request, channel
+
+    def test_shared_room_list_public(self):
+        """
+        A room should show up in the shared list of rooms between two users
+        if it is public.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+
+        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        request, channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 1)
+        self.assertEquals(channel.json_body["joined"][0], room)
+
+    def test_shared_room_list_private(self):
+        """
+        A room should show up in the shared list of rooms between two users
+        if it is private.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+
+        room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        request, channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 1)
+        self.assertEquals(channel.json_body["joined"][0], room)
+
+    def test_shared_room_list_mixed(self):
+        """
+        The shared room list between two users should contain both public and private
+        rooms.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+
+        room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
+        self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
+        self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
+        self.helper.join(room_public, user=u2, tok=u2_token)
+        self.helper.join(room_private, user=u1, tok=u1_token)
+
+        request, channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 2)
+        self.assertTrue(room_public in channel.json_body["joined"])
+        self.assertTrue(room_private in channel.json_body["joined"])
+
+    def test_shared_room_list_after_leave(self):
+        """
+        A room should no longer be considered shared if the other
+        user has left it.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+
+        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        # Assert user directory is not empty
+        request, channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 1)
+        self.assertEquals(channel.json_body["joined"][0], room)
+
+        self.helper.leave(room, user=u1, tok=u1_token)
+
+        request, channel = self._get_shared_rooms(u2_token, u1)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index fa3a3ec1bd..a31e44c97e 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
 import json
 
 import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v2_alpha import read_marker, sync
 
 from tests import unittest
 from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             "GET", sync_url % (access_token, next_batch)
         )
         self.assertRaises(TimedOutException, self.render, request)
+
+
+class UnreadMessagesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        read_marker.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.url = "/sync?since=%s"
+        self.next_batch = "s0"
+
+        # Register the first user (used to check the unread counts).
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        # Create the room we'll check unread counts for.
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # Register the second user (used to send events to the room).
+        self.user2 = self.register_user("kermit2", "monkey")
+        self.tok2 = self.login("kermit2", "monkey")
+
+        # Change the power levels of the room so that the second user can send state
+        # events.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.PowerLevels,
+            {
+                "users": {self.user_id: 100, self.user2: 100},
+                "users_default": 0,
+                "events": {
+                    "m.room.name": 50,
+                    "m.room.power_levels": 100,
+                    "m.room.history_visibility": 100,
+                    "m.room.canonical_alias": 50,
+                    "m.room.avatar": 50,
+                    "m.room.tombstone": 100,
+                    "m.room.server_acl": 100,
+                    "m.room.encryption": 100,
+                },
+                "events_default": 0,
+                "state_default": 50,
+                "ban": 50,
+                "kick": 50,
+                "redact": 50,
+                "invite": 0,
+            },
+            tok=self.tok,
+        )
+
+    def test_unread_counts(self):
+        """Tests that /sync returns the right value for the unread count (MSC2654)."""
+
+        # Check that our own messages don't increase the unread count.
+        self.helper.send(self.room_id, "hello", tok=self.tok)
+        self._check_unread_count(0)
+
+        # Join the new user and check that this doesn't increase the unread count.
+        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+        self._check_unread_count(0)
+
+        # Check that the new user sending a message increases our unread count.
+        res = self.helper.send(self.room_id, "hello", tok=self.tok2)
+        self._check_unread_count(1)
+
+        # Send a read receipt to tell the server we've read the latest event.
+        body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+        request, channel = self.make_request(
+            "POST",
+            "/rooms/%s/read_markers" % self.room_id,
+            body,
+            access_token=self.tok,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Check that the unread counter is back to 0.
+        self._check_unread_count(0)
+
+        # Check that room name changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
+        )
+        self._check_unread_count(1)
+
+        # Check that room topic changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
+        )
+        self._check_unread_count(2)
+
+        # Check that encrypted messages increase the unread counter.
+        self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+        self._check_unread_count(3)
+
+        # Check that custom events with a body increase the unread counter.
+        self.helper.send_event(
+            self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that edits don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": "hello",
+                "msgtype": "m.text",
+                "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+            },
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that notices don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"body": "hello", "msgtype": "m.notice"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that tombstone events changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.Tombstone,
+            {"replacement_room": "!someroom:test"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: True):
+        """Syncs and compares the unread count with the expected value."""
+
+        request, channel = self.make_request(
+            "GET", self.url % self.next_batch, access_token=self.tok,
+        )
+        self.render(request)
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        self.assertEqual(
+            room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
+        )
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 99eb477149..6850c666be 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
         Tell the mock http client to expect an outgoing GET request for the given key
         """
 
-        def get_json(destination, path, ignore_backoff=False, **kwargs):
+        async def get_json(destination, path, ignore_backoff=False, **kwargs):
             self.assertTrue(ignore_backoff)
             self.assertEqual(destination, server_name)
             key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
 
         # wire up outbound POST /key/v2/query requests from hs2 so that they
         # will be forwarded to hs1
-        def post_json(destination, path, data):
+        async def post_json(destination, path, data):
             self.assertEqual(destination, self.hs.hostname)
             self.assertEqual(
                 path, "/_matrix/key/v2/query",
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 1ca648ef2b..f4f3e56777 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -12,22 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 import os
 import shutil
 import tempfile
 from binascii import unhexlify
 from io import BytesIO
 from typing import Optional
+from urllib import parse
 
 from mock import Mock
-from six.moves.urllib import parse
 
 import attr
-import PIL.Image as Image
 from parameterized import parameterized_class
+from PIL import Image as Image
 
+from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
@@ -79,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
 
         # This uses a real blocking threadpool so we have to wait for it to be
         # actually done :/
-        x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+        x = defer.ensureDeferred(
+            self.media_storage.ensure_media_is_in_local_cache(file_info)
+        )
 
         # Hotloop until the threadpool does its job...
         self.wait_on_thread(x)
@@ -232,7 +233,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(len(self.fetches), 1)
         self.assertEqual(self.fetches[0][1], "example.com")
         self.assertEqual(
-            self.fetches[0][2], "/_matrix/media/v1/download/" + self.media_id
+            self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
         )
         self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
 
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 2826211f32..c00a7b9114 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -12,8 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import json
 import os
+import re
+
+from mock import patch
 
 import attr
 
@@ -29,7 +32,7 @@ from tests.server import FakeTransport
 
 
 @attr.s
-class FakeResponse(object):
+class FakeResponse:
     version = attr.ib()
     code = attr.ib()
     phrase = attr.ib()
@@ -40,7 +43,7 @@ class FakeResponse(object):
     @property
     def request(self):
         @attr.s
-        class FakeTransport(object):
+        class FakeTransport:
             absoluteURI = self.absoluteURI
 
         return FakeTransport()
@@ -108,7 +111,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
 
         self.lookups = {}
 
-        class Resolver(object):
+        class Resolver:
             def resolveHostName(
                 _self,
                 resolutionReceiver,
@@ -131,7 +134,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.reactor.nameResolver = Resolver()
 
     def test_cache_returns_correct_type(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         request, channel = self.make_request(
             "GET", "url_preview?url=http://matrix.org", shorthand=False
@@ -187,7 +190,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         )
 
     def test_non_ascii_preview_httpequiv(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
             b"<html><head>"
@@ -221,7 +224,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
     def test_non_ascii_preview_content_type(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
             b"<html><head>"
@@ -254,7 +257,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
 
     def test_overlong_title(self):
-        self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
 
         end_content = (
             b"<html><head>"
@@ -292,7 +295,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         IP addresses can be previewed directly.
         """
-        self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
         request, channel = self.make_request(
             "GET", "url_preview?url=http://example.com", shorthand=False
@@ -439,7 +442,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         # Hardcode the URL resolving to the IP we want.
         self.lookups["example.com"] = [
             (IPv4Address, "1.1.1.2"),
-            (IPv4Address, "8.8.8.8"),
+            (IPv4Address, "10.1.2.3"),
         ]
 
         request, channel = self.make_request(
@@ -518,7 +521,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
         """
         Accept-Language header is sent to the remote server
         """
-        self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
+        self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
 
         # Build and make a request to the server
         request, channel = self.make_request(
@@ -562,3 +565,126 @@ class URLPreviewTests(unittest.HomeserverTestCase):
             ),
             server.data,
         )
+
+    def test_oembed_photo(self):
+        """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
+        # Route the HTTP version to an HTTP endpoint so that the tests work.
+        with patch.dict(
+            "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+            {
+                re.compile(
+                    r"http://twitter\.com/.+/status/.+"
+                ): "http://publish.twitter.com/oembed",
+            },
+            clear=True,
+        ):
+
+            self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+            self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+            result = {
+                "version": "1.0",
+                "type": "photo",
+                "url": "http://cdn.twitter.com/matrixdotorg",
+            }
+            oembed_content = json.dumps(result).encode("utf-8")
+
+            end_content = (
+                b"<html><head>"
+                b"<title>Some Title</title>"
+                b'<meta property="og:description" content="hi" />'
+                b"</head></html>"
+            )
+
+            request, channel = self.make_request(
+                "GET",
+                "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+                shorthand=False,
+            )
+            request.render(self.preview_url)
+            self.pump()
+
+            client = self.reactor.tcpClients[0][2].buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                (
+                    b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                    b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+                )
+                % (len(oembed_content),)
+                + oembed_content
+            )
+
+            self.pump()
+
+            client = self.reactor.tcpClients[1][2].buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                (
+                    b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                    b'Content-Type: text/html; charset="utf8"\r\n\r\n'
+                )
+                % (len(end_content),)
+                + end_content
+            )
+
+            self.pump()
+
+            self.assertEqual(channel.code, 200)
+            self.assertEqual(
+                channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
+            )
+
+    def test_oembed_rich(self):
+        """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
+        # Route the HTTP version to an HTTP endpoint so that the tests work.
+        with patch.dict(
+            "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
+            {
+                re.compile(
+                    r"http://twitter\.com/.+/status/.+"
+                ): "http://publish.twitter.com/oembed",
+            },
+            clear=True,
+        ):
+
+            self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
+
+            result = {
+                "version": "1.0",
+                "type": "rich",
+                "html": "<div>Content Preview</div>",
+            }
+            end_content = json.dumps(result).encode("utf-8")
+
+            request, channel = self.make_request(
+                "GET",
+                "url_preview?url=http://twitter.com/matrixdotorg/status/12345",
+                shorthand=False,
+            )
+            request.render(self.preview_url)
+            self.pump()
+
+            client = self.reactor.tcpClients[0][2].buildProtocol(None)
+            server = AccumulatingProtocol()
+            server.makeConnection(FakeTransport(client, self.reactor))
+            client.makeConnection(FakeTransport(server, self.reactor))
+            client.dataReceived(
+                (
+                    b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
+                    b'Content-Type: application/json; charset="utf8"\r\n\r\n'
+                )
+                % (len(end_content),)
+                + end_content
+            )
+
+            self.pump()
+            self.assertEqual(channel.code, 200)
+            self.assertEqual(
+                channel.json_body,
+                {"og:title": None, "og:description": "Content Preview"},
+            )
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
new file mode 100644
index 0000000000..2d021f6565
--- /dev/null
+++ b/tests/rest/test_health.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.health import HealthResource
+
+from tests import unittest
+
+
+class HealthCheckTests(unittest.HomeserverTestCase):
+    def setUp(self):
+        super().setUp()
+
+        # replace the JsonResource with a HealthResource.
+        self.resource = HealthResource()
+
+    def test_health(self):
+        request, channel = self.make_request("GET", "/health", shorthand=False)
+        self.render(request)
+
+        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/server.py b/tests/server.py
index 1644710aa0..48e45c6c8b 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,8 +2,6 @@ import json
 import logging
 from io import BytesIO
 
-from six import text_type
-
 import attr
 from zope.interface import implementer
 
@@ -37,7 +35,7 @@ class TimedOutException(Exception):
 
 
 @attr.s
-class FakeChannel(object):
+class FakeChannel:
     """
     A fake Twisted Web Channel (the part that interfaces with the
     wire).
@@ -174,7 +172,7 @@ def make_request(
     if not path.startswith(b"/"):
         path = b"/" + path
 
-    if isinstance(content, text_type):
+    if isinstance(content, str):
         content = content.encode("utf8")
 
     site = FakeSite()
@@ -239,11 +237,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def __init__(self):
         self.threadpool = ThreadPool(self)
 
+        self._tcp_callbacks = {}
         self._udp = []
         lookups = self.lookups = {}
 
         @implementer(IResolverSimple)
-        class FakeResolver(object):
+        class FakeResolver:
             def getHostByName(self, name, timeout=None):
                 if name not in lookups:
                     return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
@@ -270,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def getThreadPool(self):
         return self.threadpool
 
+    def add_tcp_client_callback(self, host, port, callback):
+        """Add a callback that will be invoked when we receive a connection
+        attempt to the given IP/port using `connectTCP`.
+
+        Note that the callback gets run before we return the connection to the
+        client, which means callbacks cannot block while waiting for writes.
+        """
+        self._tcp_callbacks[(host, port)] = callback
+
+    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
+        """Fake L{IReactorTCP.connectTCP}.
+        """
+
+        conn = super().connectTCP(
+            host, port, factory, timeout=timeout, bindAddress=None
+        )
+
+        callback = self._tcp_callbacks.get((host, port))
+        if callback:
+            callback()
+
+        return conn
+
 
 class ThreadPool:
     """
@@ -349,7 +371,7 @@ def get_clock():
 
 
 @attr.s(cmp=False)
-class FakeTransport(object):
+class FakeTransport:
     """
     A twisted.internet.interfaces.ITransport implementation which sends all its data
     straight into an IProtocol object: it exists to connect two IProtocols together.
@@ -488,7 +510,7 @@ class FakeTransport(object):
         try:
             self.other.dataReceived(to_write)
         except Exception as e:
-            logger.warning("Exception writing to protocol: %s", e)
+            logger.exception("Exception writing to protocol: %s", e)
             return
 
         self.buffer = self.buffer[len(to_write) :]
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 99908edba3..973338ea71 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import (
 )
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 from tests.utils import default_config
 
@@ -66,7 +67,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             raise Exception("Failed to find reference to ResourceLimitsServerNotices")
 
         self._rlsn._store.user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(1000)
+            side_effect=lambda user_id: make_awaitable(1000)
         )
         self._rlsn._server_notices_manager.send_notice = Mock(
             return_value=defer.succeed(Mock())
@@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             return_value=defer.succeed("!something:localhost")
         )
         self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
-        self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
+        self._rlsn._store.get_tags_for_room = Mock(
+            side_effect=lambda user_id, room_id: make_awaitable({})
+        )
 
     @override_config({"hs_disabled": True})
     def test_maybe_send_server_notice_disabled_hs(self):
@@ -101,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
@@ -119,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )
 
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -155,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         """
         self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
         self._rlsn._store.user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(None)
+            side_effect=lambda user_id: make_awaitable(None)
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -214,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
 
@@ -258,10 +261,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         self.user_id = "@user_id:test"
 
     def test_server_notice_only_sent_once(self):
-        self.store.get_monthly_active_count = Mock(return_value=1000)
+        self.store.get_monthly_active_count = Mock(
+            side_effect=lambda: make_awaitable(1000)
+        )
 
         self.store.user_last_seen_monthly_active = Mock(
-            return_value=defer.succeed(1000)
+            side_effect=lambda user_id: make_awaitable(1000)
         )
 
         # Call the function multiple times to ensure we only send the notice once
@@ -275,7 +280,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
             self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
         )
 
-        token = self.get_success(self.event_source.get_current_token())
+        token = self.event_source.get_current_token()
         events, _ = self.get_success(
             self.store.get_recent_events_for_room(
                 room_id, limit=100, end_token=token.room_key
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index a44960203e..ad9bbef9d2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,11 +14,12 @@
 # limitations under the License.
 
 import itertools
-
-from six.moves import zip
+from typing import List
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
@@ -43,7 +44,12 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
 ORIGIN_SERVER_TS = 0
 
 
-class FakeEvent(object):
+class FakeClock:
+    def sleep(self, msec):
+        return defer.succeed(None)
+
+
+class FakeEvent:
     """A fake event we use as a convenience.
 
     NOTE: Again as a convenience we use "node_ids" rather than event_ids to
@@ -419,6 +425,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
             else:
                 state_d = resolve_events_with_store(
+                    FakeClock(),
                     ROOM_ID,
                     RoomVersions.V2.identifier,
                     [state_at_event[n] for n in prev_events],
@@ -426,7 +433,7 @@ class StateTestCase(unittest.TestCase):
                     state_res_store=TestStateResolutionStore(event_map),
                 )
 
-                state_before = self.successResultOf(state_d)
+                state_before = self.successResultOf(defer.ensureDeferred(state_d))
 
             state_after = dict(state_before)
             if fake_event.state_key is not None:
@@ -567,6 +574,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
+            FakeClock(),
             ROOM_ID,
             RoomVersions.V2.identifier,
             [self.state_at_bob, self.state_at_charlie],
@@ -574,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             state_res_store=TestStateResolutionStore(self.event_map),
         )
 
-        state = self.successResultOf(state_d)
+        state = self.successResultOf(defer.ensureDeferred(state_d))
 
         self.assert_dict(self.expected_combined_state, state)
 
@@ -587,7 +595,7 @@ def pairwise(iterable):
 
 
 @attr.s
-class TestStateResolutionStore(object):
+class TestStateResolutionStore:
     event_map = attr.ib()
 
     def get_events(self, event_ids, allow_rejected=False):
@@ -601,9 +609,11 @@ class TestStateResolutionStore(object):
             Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
-        return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        return defer.succeed(
+            {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        )
 
-    def _get_auth_chain(self, event_ids):
+    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -615,10 +625,10 @@ class TestStateResolutionStore(object):
                presence of rejected events
 
         Args:
-            event_ids (list): The event IDs of the events to fetch the auth
+            event_ids: The event IDs of the events to fetch the auth
                 chain for. Must be state events.
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            List of event IDs of the auth chain.
         """
 
         # Simple DFS for auth chain
@@ -641,4 +651,4 @@ class TestStateResolutionStore(object):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
-        return set(chains[0]).union(*chains[1:]) - common
+        return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..f5afed017c 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -99,7 +99,7 @@ class CacheTestCase(unittest.HomeserverTestCase):
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     @defer.inlineCallbacks
     def test_passthrough(self):
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 return key
@@ -113,7 +113,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_hit(self):
         callcount = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -131,7 +131,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_invalidate(self):
         callcount = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -149,7 +149,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         self.assertEquals(callcount[0], 2)
 
     def test_invalidate_missing(self):
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 return key
@@ -160,7 +160,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     def test_max_entries(self):
         callcount = [0]
 
-        class A(object):
+        class A:
             @cached(max_entries=10)
             def func(self, key):
                 callcount[0] += 1
@@ -187,7 +187,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
 
         d = defer.succeed(123)
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -205,7 +205,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]
         callcount2 = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -238,7 +238,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]
         callcount2 = [0]
 
-        class A(object):
+        class A:
             @cached(max_entries=2)
             def func(self, key):
                 callcount[0] += 1
@@ -275,7 +275,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
         callcount = [0]
         callcount2 = [0]
 
-        class A(object):
+        class A:
             @cached()
             def func(self, key):
                 callcount[0] += 1
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         self.table_name = "table_" + hs.get_secrets().token_hex(6)
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "create",
                 lambda x, *a: x.execute(*a),
                 "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "index",
                 lambda x, *a: x.execute(*a),
                 "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["hello"], ["there"]]
 
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "test",
-                self.storage.db.simple_upsert_many_txn,
+                self.storage.db_pool.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.db.simple_select_list(
+            self.storage.db_pool.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
         value_values = [["bleb"]]
 
         self.get_success(
-            self.storage.db.runInteraction(
+            self.storage.db_pool.runInteraction(
                 "test",
-                self.storage.db.simple_upsert_many_txn,
+                self.storage.db_pool.simple_upsert_many_txn,
                 self.table_name,
                 key_names,
                 key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
 
         # Check results are what we expect
         res = self.get_success(
-            self.storage.db.simple_select_list(
+            self.storage.db_pool.simple_select_list(
                 self.table_name, None, ["id, username, value"]
             )
         )
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..cb808d4de4 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,13 +24,14 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
-from synapse.storage.database import Database, make_conn
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
 
 
@@ -178,14 +179,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_appservice_state_none(self):
         service = Mock(id="999")
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(None, state)
 
     @defer.inlineCallbacks
     def test_get_appservice_state_up(self):
         yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
         service = Mock(id=self.as_list[0]["id"])
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(ApplicationServiceState.UP, state)
 
     @defer.inlineCallbacks
@@ -194,20 +195,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         service = Mock(id=self.as_list[1]["id"])
-        state = yield self.store.get_appservice_state(service)
+        state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
         self.assertEquals(ApplicationServiceState.DOWN, state)
 
     @defer.inlineCallbacks
     def test_get_appservices_by_state_none(self):
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(0, len(services))
 
     @defer.inlineCallbacks
     def test_set_appservices_state_down(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -219,9 +222,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_set_appservices_state_multiple_up(self):
         service = Mock(id=self.as_list[1]["id"])
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
-        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+        )
+        yield defer.ensureDeferred(
+            self.store.set_appservice_state(service, ApplicationServiceState.UP)
+        )
         rows = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
                 "SELECT as_id FROM application_services_state WHERE state=?"
@@ -234,7 +243,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def test_create_appservice_txn_first(self):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 1)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -246,7 +257,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_last_txn(service.id, 9643)  # AS is falling behind
         yield self._insert_txn(service.id, 9644, events)
         yield self._insert_txn(service.id, 9645, events)
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9646)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -256,7 +269,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[0]["id"])
         events = [Mock(event_id="e1"), Mock(event_id="e2")]
         yield self._set_last_txn(service.id, 9643)
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -277,7 +292,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._insert_txn(self.as_list[2]["id"], 10, events)
         yield self._insert_txn(self.as_list[3]["id"], 9643, events)
 
-        txn = yield self.store.create_appservice_txn(service, events)
+        txn = yield defer.ensureDeferred(
+            self.store.create_appservice_txn(service, events)
+        )
         self.assertEquals(txn.id, 9644)
         self.assertEquals(txn.events, events)
         self.assertEquals(txn.service, service)
@@ -289,7 +306,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         txn_id = 1
 
         yield self._insert_txn(service.id, txn_id, events)
-        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        yield defer.ensureDeferred(
+            self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        )
 
         res = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
@@ -315,7 +334,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         txn_id = 5
         yield self._set_last_txn(service.id, 4)
         yield self._insert_txn(service.id, txn_id, events)
-        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        yield defer.ensureDeferred(
+            self.store.complete_appservice_txn(txn_id=txn_id, service=service)
+        )
 
         res = yield self.db_pool.runQuery(
             self.engine.convert_param_style(
@@ -339,7 +360,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def test_get_oldest_unsent_txn_none(self):
         service = Mock(id=self.as_list[0]["id"])
 
-        txn = yield self.store.get_oldest_unsent_txn(service)
+        txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
         self.assertEquals(None, txn)
 
     @defer.inlineCallbacks
@@ -349,14 +370,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
 
         # we aren't testing store._base stuff here, so mock this out
-        self.store.get_events_as_list = Mock(return_value=events)
+        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
 
         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)
         yield self._insert_txn(service.id, 11, other_events)
         yield self._insert_txn(service.id, 12, other_events)
 
-        txn = yield self.store.get_oldest_unsent_txn(service)
+        txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
         self.assertEquals(service, txn.service)
         self.assertEquals(10, txn.id)
         self.assertEquals(events, txn.events)
@@ -366,8 +387,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
 
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(1, len(services))
         self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +400,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
         yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
 
-        services = yield self.store.get_appservices_by_state(
-            ApplicationServiceState.DOWN
+        services = yield defer.ensureDeferred(
+            self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
         )
         self.assertEquals(2, len(services))
         self.assertEquals(
@@ -391,7 +412,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..02aae1c13d 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -1,7 +1,5 @@
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.storage.background_updates import BackgroundUpdater
 
 from tests import unittest
@@ -9,7 +7,9 @@ from tests import unittest
 
 class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, homeserver):
-        self.updates = self.hs.get_datastore().db.updates  # type: BackgroundUpdater
+        self.updates = (
+            self.hs.get_datastore().db_pool.updates
+        )  # type: BackgroundUpdater
         # the base test class should have run the real bg updates for us
         self.assertTrue(
             self.get_success(self.updates.has_completed_background_updates())
@@ -29,18 +29,17 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         store = self.hs.get_datastore()
         self.get_success(
-            store.db.simple_insert(
+            store.db_pool.simple_insert(
                 "background_updates",
                 values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
             )
         )
 
         # first step: make a bit of progress
-        @defer.inlineCallbacks
-        def update(progress, count):
-            yield self.clock.sleep((count * duration_ms) / 1000)
+        async def update(progress, count):
+            await self.clock.sleep((count * duration_ms) / 1000)
             progress = {"my_key": progress["my_key"] + 1}
-            yield store.db.runInteraction(
+            await store.db_pool.runInteraction(
                 "update_progress",
                 self.updates._background_update_progress_txn,
                 "test_update",
@@ -65,13 +64,12 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
 
         # second step: complete the update
         # we should now get run with a much bigger number of items to update
-        @defer.inlineCallbacks
-        def update(progress, count):
+        async def update(progress, count):
             self.assertEqual(progress, {"my_key": 2})
             self.assertAlmostEqual(
                 count, target_background_update_duration_ms / duration_ms, places=0,
             )
-            yield self.updates._end_background_update("test_update")
+            await self.updates._end_background_update("test_update")
             return count
 
         self.update_handler.side_effect = update
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 278961c331..40ba652248 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,11 +21,11 @@ from mock import Mock
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.engines import create_engine
 
 from tests import unittest
-from tests.utils import TestHomeServer
+from tests.utils import TestHomeServer, default_config
 
 
 class SQLBaseStoreTestCase(unittest.TestCase):
@@ -49,10 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 
         self.db_pool.runWithConnection = runWithConnection
 
-        config = Mock()
-        config._disable_native_upserts = True
-        config.caches = Mock()
-        config.caches.event_cache_size = 1
+        config = default_config(name="test", parse=True)
         hs = TestHomeServer("test", config=config)
 
         sqlite_config = {"name": "sqlite3"}
@@ -60,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
 
-        db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+        db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
         db._db_pool = self.db_pool
 
         self.datastore = SQLBaseStore(db, None, hs)
@@ -69,8 +66,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_insert(
-            table="tablename", values={"columname": "Value"}
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename", values={"columname": "Value"}
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -81,10 +80,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_insert_3cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_insert(
-            table="tablename",
-            # Use OrderedDict() so we can assert on the SQL generated
-            values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_insert(
+                table="tablename",
+                # Use OrderedDict() so we can assert on the SQL generated
+                values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -96,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
 
-        value = yield self.datastore.db.simple_select_one_onecol(
-            table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+        value = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one_onecol(
+                table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
+            )
         )
 
         self.assertEquals("Value", value)
@@ -110,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 1
         self.mock_txn.fetchone.return_value = (1, 2, 3)
 
-        ret = yield self.datastore.db.simple_select_one(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            retcols=["colA", "colB", "colC"],
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one(
+                table="tablename",
+                keyvalues={"keycol": "TheKey"},
+                retcols=["colA", "colB", "colC"],
+            )
         )
 
         self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -126,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.rowcount = 0
         self.mock_txn.fetchone.return_value = None
 
-        ret = yield self.datastore.db.simple_select_one(
-            table="tablename",
-            keyvalues={"keycol": "Not here"},
-            retcols=["colA"],
-            allow_none=True,
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_one(
+                table="tablename",
+                keyvalues={"keycol": "Not here"},
+                retcols=["colA"],
+                allow_none=True,
+            )
         )
 
         self.assertFalse(ret)
@@ -141,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
         self.mock_txn.description = (("colA", None, None, None, None, None, None),)
 
-        ret = yield self.datastore.db.simple_select_list(
-            table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+        ret = yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_select_list(
+                table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
+            )
         )
 
         self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@@ -154,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_1col(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_update_one(
-            table="tablename",
-            keyvalues={"keycol": "TheKey"},
-            updatevalues={"columnname": "New Value"},
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_one(
+                table="tablename",
+                keyvalues={"keycol": "TheKey"},
+                updatevalues={"columnname": "New Value"},
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -169,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_update_one_4cols(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_update_one(
-            table="tablename",
-            keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
-            updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_update_one(
+                table="tablename",
+                keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
+                updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
@@ -184,8 +197,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
     def test_delete_one(self):
         self.mock_txn.rowcount = 1
 
-        yield self.datastore.db.simple_delete_one(
-            table="tablename", keyvalues={"keycol": "Go away"}
+        yield defer.ensureDeferred(
+            self.datastore.db_pool.simple_delete_one(
+                table="tablename", keyvalues={"keycol": "Go away"}
+            )
         )
 
         self.mock_txn.execute.assert_called_with(
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..080761d1d2 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
 
         # Create a test user and room
         self.user = UserID("alice", "test")
-        self.requester = Requester(self.user, None, False, None, None)
+        self.requester = Requester(self.user, None, False, False, None, None)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
 
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
         """
         # Make sure we don't clash with in progress updates.
         self.assertTrue(
-            self.store.db.updates._all_done, "Background updates are still ongoing"
+            self.store.db_pool.updates._all_done, "Background updates are still ongoing"
         )
 
         schema_path = os.path.join(
             prepare_database.dir_path,
-            "data_stores",
+            "databases",
             "main",
             "schema",
             "delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
             prepare_database.executescript(txn, schema_path)
 
         self.get_success(
-            self.store.db.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test_delete_forward_extremities", run_delta_file
             )
         )
 
         # Ugh, have to reset this flag
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
     def test_soft_failed_extremities_handled_correctly(self):
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         # Create a test user and room
         self.user = UserID.from_string(self.register_user("user1", "password"))
         self.token1 = self.login("user1", "password")
-        self.requester = Requester(self.user, None, False, None, None)
+        self.requester = Requester(self.user, None, False, False, None, None)
         info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
         self.room_id = info["room_id"]
         self.event_creator = homeserver.get_event_creation_handler()
@@ -271,7 +271,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
 
         # Pump the reactor repeatedly so that the background updates have a
         # chance to run.
-        self.pump(10 * 60)
+        self.pump(20)
 
         latest_event_ids = self.get_success(
             self.store.get_latest_event_ids_in_room(self.room_id)
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
         self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
             "3"
         ] = 300000
+
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         # All entries within time frame
         self.assertEqual(
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
             3,
         )
         # Oldest room to expire
-        self.pump(1)
+        self.pump(1.01)
         self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
         self.assertEqual(
             len(
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..370c247e16 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -16,13 +16,12 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse.rest.admin
 from synapse.http.site import XForwardedForRequest
 from synapse.rest.client.v1 import login
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
@@ -86,7 +85,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +116,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         self.pump(0)
 
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -155,7 +154,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         user_id = "@user:server"
 
         self.store.get_monthly_active_count = Mock(
-            return_value=defer.succeed(lots_of_users)
+            side_effect=lambda: make_awaitable(lots_of_users)
         )
         self.get_success(
             self.store.insert_client_ip(
@@ -204,10 +203,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
     def test_devices_last_seen_bg_update(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         user_id = "@user:id"
@@ -225,7 +224,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # But clear the associated entry in devices table
         self.get_success(
-            self.store.db.simple_update(
+            self.store.db_pool.simple_update(
                 table="devices",
                 keyvalues={"user_id": user_id, "device_id": device_id},
                 updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +251,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # Register the background update to run again.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "devices_last_seen",
@@ -263,14 +262,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # We should now get the correct result again
@@ -293,10 +292,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
     def test_old_user_ips_pruned(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         user_id = "@user:id"
@@ -315,7 +314,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should see that in the DB
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +340,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
 
         # We should get no results.
         result = self.get_success(
-            self.store.db.simple_select_list(
+            self.store.db_pool.simple_select_list(
                 table="user_ips",
                 keyvalues={"user_id": user_id},
                 retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index c2539b353a..34ae8c9da7 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -34,9 +34,11 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_store_new_device(self):
-        yield self.store.store_device("user_id", "device_id", "display_name")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device_id", "display_name")
+        )
 
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertDictContainsSubset(
             {
                 "user_id": "user_id",
@@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield self.store.store_device("user_id", "device1", "display_name 1")
-        yield self.store.store_device("user_id", "device2", "display_name 2")
-        yield self.store.store_device("user_id2", "device3", "display_name 3")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device1", "display_name 1")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device2", "display_name 2")
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id2", "device3", "display_name 3")
+        )
 
-        res = yield self.store.get_devices_by_user("user_id")
+        res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
         self.assertEqual(2, len(res.keys()))
         self.assertDictContainsSubset(
             {
@@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids, ["somehost"]
+        yield defer.ensureDeferred(
+            self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
-            "somehost", -1, limit=100
+        now_stream_id, device_updates = yield defer.ensureDeferred(
+            self.store.get_device_updates_by_remote("somehost", -1, limit=100)
         )
 
         # Check original device_ids are contained within these updates
@@ -99,29 +107,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_update_device(self):
-        yield self.store.store_device("user_id", "device_id", "display_name 1")
+        yield defer.ensureDeferred(
+            self.store.store_device("user_id", "device_id", "display_name 1")
+        )
 
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do a no-op first
-        yield self.store.update_device("user_id", "device_id")
-        res = yield self.store.get_device("user_id", "device_id")
+        yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 1", res["display_name"])
 
         # do the update
-        yield self.store.update_device(
-            "user_id", "device_id", new_display_name="display_name 2"
+        yield defer.ensureDeferred(
+            self.store.update_device(
+                "user_id", "device_id", new_display_name="display_name 2"
+            )
         )
 
         # check it worked
-        res = yield self.store.get_device("user_id", "device_id")
+        res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
         self.assertEqual("display_name 2", res["display_name"])
 
     @defer.inlineCallbacks
     def test_update_unknown_device(self):
         with self.assertRaises(synapse.api.errors.StoreError) as cm:
-            yield self.store.update_device(
-                "user_id", "unknown_device_id", new_display_name="display_name 2"
+            yield defer.ensureDeferred(
+                self.store.update_device(
+                    "user_id", "unknown_device_id", new_display_name="display_name 2"
+                )
             )
         self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..da93ca3980 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,35 +34,53 @@ class DirectoryStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
         self.assertEquals(
             ["#my-room:test"],
-            (yield self.store.get_aliases_for_room(self.room.to_string())),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_aliases_for_room(self.room.to_string())
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (yield self.store.get_association_from_room_alias(self.alias)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
         )
 
-        room_id = yield self.store.delete_room_alias(self.alias)
+        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
         self.assertEqual(self.room.to_string(), room_id)
 
         self.assertIsNone(
-            (yield self.store.get_association_from_room_alias(self.alias))
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            )
         )
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..3fc4bb13b6 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -30,11 +30,15 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device("user", "device", None)
+        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
-        yield self.store.set_e2e_device_keys("user", "device", now, json)
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
 
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
@@ -45,14 +49,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.store_device("user", "device", None)
+        yield defer.ensureDeferred(self.store.store_device("user", "device", None))
 
-        changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+        changed = yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         self.assertTrue(changed)
 
         # If we try to upload the same key then we should be told nothing
         # changed
-        changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
+        changed = yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
         self.assertFalse(changed)
 
     @defer.inlineCallbacks
@@ -60,10 +68,16 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = {"key": "value"}
 
-        yield self.store.set_e2e_device_keys("user", "device", now, json)
-        yield self.store.store_device("user", "device", "display_name")
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user", "device", now, json)
+        )
+        yield defer.ensureDeferred(
+            self.store.store_device("user", "device", "display_name")
+        )
 
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]
@@ -75,18 +89,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
     def test_multiple_devices(self):
         now = 1470174257070
 
-        yield self.store.store_device("user1", "device1", None)
-        yield self.store.store_device("user1", "device2", None)
-        yield self.store.store_device("user2", "device1", None)
-        yield self.store.store_device("user2", "device2", None)
+        yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
+        yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
+        yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
+        yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
 
-        yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
-        yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
-        yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
-        yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
+        )
+        yield defer.ensureDeferred(
+            self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
+        )
 
-        res = yield self.store.get_e2e_device_keys(
-            (("user1", "device1"), ("user2", "device2"))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys_for_cs_api(
+                (("user1", "device1"), ("user2", "device2"))
+            )
         )
         self.assertIn("user1", res)
         self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             )
 
         for i in range(0, 20):
-            self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+            self.get_success(
+                self.store.db_pool.runInteraction("insert", insert_event, i)
+            )
 
         # this should get the last ten
         r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
         for i in range(0, 20):
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room1)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room1)
             )
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room2)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room2)
             )
             self.get_success(
-                self.store.db.runInteraction("insert", insert_event, i, room3)
+                self.store.db_pool.runInteraction("insert", insert_event, i, room3)
             )
 
         # Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
 
             depth = depth_map[event_id]
 
-            self.store.db.simple_insert_txn(
+            self.store.db_pool.simple_insert_txn(
                 txn,
                 table="events",
                 values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
                 },
             )
 
-            self.store.db.simple_insert_many_txn(
+            self.store.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_auth",
                 values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         for event_id in auth_graph:
             next_stream_ordering += 1
             self.get_success(
-                self.store.db.runInteraction(
+                self.store.db_pool.runInteraction(
                     "insert", insert_event, event_id, next_stream_ordering
                 )
             )
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index a7b85004e5..949846fe33 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
         room_creator = self.hs.get_room_creation_handler()
 
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
 
         # Real events, forward extremities
         events = [(3, 2), (6, 2), (4, 6)]
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index b45bc9c115..c0595963dd 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_http(self):
-        yield self.store.get_unread_push_actions_for_user_in_range_for_http(
-            USER_ID, 0, 1000, 20
+        yield defer.ensureDeferred(
+            self.store.get_unread_push_actions_for_user_in_range_for_http(
+                USER_ID, 0, 1000, 20
+            )
         )
 
     @defer.inlineCallbacks
     def test_get_unread_push_actions_for_user_in_range_for_email(self):
-        yield self.store.get_unread_push_actions_for_user_in_range_for_email(
-            USER_ID, 0, 1000, 20
+        yield defer.ensureDeferred(
+            self.store.get_unread_push_actions_for_user_in_range_for_email(
+                USER_ID, 0, 1000, 20
+            )
         )
 
     @defer.inlineCallbacks
@@ -56,12 +60,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
 
         @defer.inlineCallbacks
         def _assert_counts(noitf_count, highlight_count):
-            counts = yield self.store.db.runInteraction(
-                "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+            counts = yield defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
+                )
             )
             self.assertEquals(
                 counts,
-                {"notify_count": noitf_count, "highlight_count": highlight_count},
+                {
+                    "notify_count": noitf_count,
+                    "unread_count": 0,  # Unread counts are tested in the sync tests.
+                    "highlight_count": highlight_count,
+                },
             )
 
         @defer.inlineCallbacks
@@ -72,28 +82,36 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
             event.internal_metadata.stream_ordering = stream
             event.depth = stream
 
-            yield self.store.add_push_actions_to_staging(
-                event.event_id, {user_id: action}
+            yield defer.ensureDeferred(
+                self.store.add_push_actions_to_staging(
+                    event.event_id, {user_id: action}, False,
+                )
             )
-            yield self.store.db.runInteraction(
-                "",
-                self.persist_events_store._set_push_actions_for_event_and_users_txn,
-                [(event, None)],
-                [(event, None)],
+            yield defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "",
+                    self.persist_events_store._set_push_actions_for_event_and_users_txn,
+                    [(event, None)],
+                    [(event, None)],
+                )
             )
 
         def _rotate(stream):
-            return self.store.db.runInteraction(
-                "", self.store._rotate_notifs_before_txn, stream
+            return defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "", self.store._rotate_notifs_before_txn, stream
+                )
             )
 
         def _mark_read(stream, depth):
-            return self.store.db.runInteraction(
-                "",
-                self.store._remove_old_push_actions_before_txn,
-                room_id,
-                user_id,
-                stream,
+            return defer.ensureDeferred(
+                self.store.db_pool.runInteraction(
+                    "",
+                    self.store._remove_old_push_actions_before_txn,
+                    room_id,
+                    user_id,
+                    stream,
+                )
             )
 
         yield _assert_counts(0, 0)
@@ -117,8 +135,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         yield _inject_actions(6, PlAIN_NOTIF)
         yield _rotate(7)
 
-        yield self.store.db.simple_delete(
-            table="event_push_actions", keyvalues={"1": 1}, desc=""
+        yield defer.ensureDeferred(
+            self.store.db_pool.simple_delete(
+                table="event_push_actions", keyvalues={"1": 1}, desc=""
+            )
         )
 
         yield _assert_counts(1, 0)
@@ -136,33 +156,43 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
     @defer.inlineCallbacks
     def test_find_first_stream_ordering_after_ts(self):
         def add_event(so, ts):
-            return self.store.db.simple_insert(
-                "events",
-                {
-                    "stream_ordering": so,
-                    "received_ts": ts,
-                    "event_id": "event%i" % so,
-                    "type": "",
-                    "room_id": "",
-                    "content": "",
-                    "processed": True,
-                    "outlier": False,
-                    "topological_ordering": 0,
-                    "depth": 0,
-                },
+            return defer.ensureDeferred(
+                self.store.db_pool.simple_insert(
+                    "events",
+                    {
+                        "stream_ordering": so,
+                        "received_ts": ts,
+                        "event_id": "event%i" % so,
+                        "type": "",
+                        "room_id": "",
+                        "content": "",
+                        "processed": True,
+                        "outlier": False,
+                        "topological_ordering": 0,
+                        "depth": 0,
+                    },
+                )
             )
 
         # start with the base case where there are no events in the table
-        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(11)
+        )
         self.assertEqual(r, 0)
 
         # now with one event
         yield add_event(2, 10)
-        r = yield self.store.find_first_stream_ordering_after_ts(9)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(9)
+        )
         self.assertEqual(r, 2)
-        r = yield self.store.find_first_stream_ordering_after_ts(10)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(10)
+        )
         self.assertEqual(r, 2)
-        r = yield self.store.find_first_stream_ordering_after_ts(11)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(11)
+        )
         self.assertEqual(r, 3)
 
         # add a bunch of dummy events to the events table
@@ -175,25 +205,37 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
         ):
             yield add_event(stream_ordering, ts)
 
-        r = yield self.store.find_first_stream_ordering_after_ts(110)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(110)
+        )
         self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
 
         # 4 and 5 are both after 120: we want 4 rather than 5
-        r = yield self.store.find_first_stream_ordering_after_ts(120)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(120)
+        )
         self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
 
-        r = yield self.store.find_first_stream_ordering_after_ts(129)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(129)
+        )
         self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
 
         # check we can get the last event
-        r = yield self.store.find_first_stream_ordering_after_ts(140)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(140)
+        )
         self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
 
         # off the end
-        r = yield self.store.find_first_stream_ordering_after_ts(160)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(160)
+        )
         self.assertEqual(r, 21)
 
         # check we can find an event at ordering zero
         yield add_event(0, 5)
-        r = yield self.store.find_first_stream_ordering_after_ts(1)
+        r = yield defer.ensureDeferred(
+            self.store.find_first_stream_ordering_after_ts(1)
+        )
         self.assertEqual(r, 0)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.unittest import HomeserverTestCase
@@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
-        self.db = self.store.db  # type: Database
+        self.db_pool = self.store.db_pool  # type: DatabasePool
 
-        self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
     def _setup_db(self, txn):
         txn.execute("CREATE SEQUENCE foobar_seq")
@@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         def _create(conn):
             return MultiWriterIdGenerator(
                 conn,
-                self.db,
+                self.db_pool,
                 instance_name=instance_name,
                 table="foobar",
                 instance_column="instance_name",
@@ -55,9 +55,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 sequence_name="foobar_seq",
             )
 
-        return self.get_success(self.db.runWithConnection(_create))
+        return self.get_success(self.db_pool.runWithConnection(_create))
 
     def _insert_rows(self, instance_name: str, number: int):
+        """Insert N rows as the given instance, inserting with stream IDs pulled
+        from the postgres sequence.
+        """
+
         def _insert(txn):
             for _ in range(number):
                 txn.execute(
@@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                     (instance_name,),
                 )
 
-        self.get_success(self.db.runInteraction("test_single_instance", _insert))
+        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+
+    def _insert_row_with_id(self, instance_name: str, stream_id: int):
+        """Insert one row as the given instance with given stream_id, updating
+        the postgres sequence position to match.
+        """
+
+        def _insert(txn):
+            txn.execute(
+                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+            )
+            txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+
+        self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
     def test_empty(self):
         """Test an ID generator against an empty database gives sensible
@@ -88,7 +105,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -98,12 +115,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(id_gen.get_positions(), {"master": 7})
-                self.assertEqual(id_gen.get_current_token("master"), 7)
+                self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
 
     def test_multi_instance(self):
         """Test that reads and writes from multiple processes are handled
@@ -116,8 +133,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         second_id_gen = self._create_id_generator("second")
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token("first"), 3)
-        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -166,7 +183,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen = self._create_id_generator()
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
-        self.assertEqual(id_gen.get_current_token("master"), 7)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
         # Try allocating a new ID gen and check that we only see position
         # advanced after we leave the context manager.
@@ -176,9 +193,179 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             self.assertEqual(stream_id, 8)
 
             self.assertEqual(id_gen.get_positions(), {"master": 7})
-            self.assertEqual(id_gen.get_current_token("master"), 7)
+            self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        self.get_success(self.db.runInteraction("test", _get_next_txn))
+        self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
 
         self.assertEqual(id_gen.get_positions(), {"master": 8})
-        self.assertEqual(id_gen.get_current_token("master"), 8)
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
+
+    def test_get_persisted_upto_position(self):
+        """Test that `get_persisted_upto_position` correctly tracks updates to
+        positions.
+        """
+
+        # The following tests are a bit cheeky in that we notify about new
+        # positions via `advance` without *actually* advancing the postgres
+        # sequence.
+
+        self._insert_row_with_id("first", 3)
+        self._insert_row_with_id("second", 5)
+
+        id_gen = self._create_id_generator("first")
+
+        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+        # Min is 3 and there is a gap between 5, so we expect it to be 3.
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        # We advance "first" straight to 6. Min is now 5 but there is no gap so
+        # we expect it to be 6
+        id_gen.advance("first", 6)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # No gap, so we expect 7.
+        id_gen.advance("second", 7)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # We haven't seen 8 yet, so we expect 7 still.
+        id_gen.advance("second", 9)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 7)
+
+        # Now that we've seen 7, 8 and 9 we can got straight to 9.
+        id_gen.advance("first", 8)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 9)
+
+        # Jump forward with gaps. The minimum is 11, even though we haven't seen
+        # 10 we know that everything before 11 must be persisted.
+        id_gen.advance("first", 11)
+        id_gen.advance("second", 15)
+        self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+
+    def test_get_persisted_upto_position_get_next(self):
+        """Test that `get_persisted_upto_position` correctly tracks updates to
+        positions when `get_next` is called.
+        """
+
+        self._insert_row_with_id("first", 3)
+        self._insert_row_with_id("second", 5)
+
+        id_gen = self._create_id_generator("first")
+
+        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+        with self.get_success(id_gen.get_next()) as stream_id:
+            self.assertEqual(stream_id, 6)
+            self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        self.assertEqual(id_gen.get_persisted_upto_position(), 6)
+
+        # We assume that so long as `get_next` does correctly advance the
+        # `persisted_upto_position` in this case, then it will be correct in the
+        # other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+    """
+
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db_pool = self.store.db_pool  # type: DatabasePool
+
+        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db_pool,
+                instance_name=instance_name,
+                table="foobar",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="foobar_seq",
+                positive=False,
+            )
+
+        return self.get_success(self.db_pool.runWithConnection(_create))
+
+    def _insert_row(self, instance_name: str, stream_id: int):
+        """Insert one row as the given instance with given stream_id.
+        """
+
+        def _insert(txn):
+            txn.execute(
+                "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+            )
+
+        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+    def test_single_instance(self):
+        """Test that reads and writes from a single process are handled
+        correctly.
+        """
+        id_gen = self._create_id_generator()
+
+        with self.get_success(id_gen.get_next()) as stream_id:
+            self._insert_row("master", stream_id)
+
+        self.assertEqual(id_gen.get_positions(), {"master": -1})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+        self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+        with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+            for stream_id in stream_ids:
+                self._insert_row("master", stream_id)
+
+        self.assertEqual(id_gen.get_positions(), {"master": -4})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+        self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+        # Test loading from DB by creating a second ID gen
+        second_id_gen = self._create_id_generator()
+
+        self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+        self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+        self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+    def test_multiple_instance(self):
+        """Tests that having multiple instances that get advanced over
+        federation works corretly.
+        """
+        id_gen_1 = self._create_id_generator("first")
+        id_gen_2 = self._create_id_generator("second")
+
+        with self.get_success(id_gen_1.get_next()) as stream_id:
+            self._insert_row("first", stream_id)
+            id_gen_2.advance("first", stream_id)
+
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+        self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+        with self.get_success(id_gen_2.get_next()) as stream_id:
+            self._insert_row("second", stream_id)
+            id_gen_1.advance("second", stream_id)
+
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+        self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py
index ab0df5ea93..7e7f1286d9 100644
--- a/tests/storage/test_main.py
+++ b/tests/storage/test_main.py
@@ -34,12 +34,16 @@ class DataStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_users_paginate(self):
-        yield self.store.register_user(self.user.to_string(), "pass")
-        yield self.store.create_profile(self.user.localpart)
-        yield self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        yield defer.ensureDeferred(
+            self.store.register_user(self.user.to_string(), "pass")
+        )
+        yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.user.localpart, self.displayname)
+        )
 
-        users, total = yield self.store.get_users_paginate(
-            0, 10, name="bc", guests=False
+        users, total = yield defer.ensureDeferred(
+            self.store.get_users_paginate(0, 10, name="bc", guests=False)
         )
 
         self.assertEquals(1, total)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..9870c74883 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.api.constants import UserTypes
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
@@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         # XXX why are we doing this here? this function is only run at startup
         # so it is odd to re-run it here.
         self.get_success(
-            self.store.db.runInteraction(
+            self.store.db_pool.runInteraction(
                 "initialise", self.store._initialise_reserved_users, threepids
             )
         )
@@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
                     self.store.user_add_threepid(user, "email", email, now, now)
                 )
 
-        d = self.store.db.runInteraction(
+        d = self.store.db_pool.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.get_success(d)
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
 
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
         self.store.user_last_seen_monthly_active = Mock(
@@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         ]
 
         self.hs.config.mau_limits_reserved_threepids = threepids
-        d = self.store.db.runInteraction(
+        d = self.store.db_pool.runInteraction(
             "initialise", self.store._initialise_reserved_users, threepids
         )
         self.get_success(d)
@@ -293,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.register_user(user_id=user2, password_hash=None))
 
         now = int(self.hs.get_clock().time_msec())
-        self.store.user_add_threepid(user1, "email", user1_email, now, now)
-        self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        self.get_success(
+            self.store.user_add_threepid(user1, "email", user1_email, now, now)
+        )
+        self.get_success(
+            self.store.user_add_threepid(user2, "email", user2_email, now, now)
+        )
 
         users = self.get_success(self.store.get_registered_reserved_users())
         self.assertEqual(len(users), len(threepids))
@@ -333,7 +344,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 9b6f7211ae..3fd0a38cf5 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -33,23 +33,36 @@ class ProfileStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_displayname(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
-        yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        yield defer.ensureDeferred(
+            self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
+        )
 
         self.assertEquals(
-            "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
+            "Frank",
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.u_frank.localpart)
+                )
+            ),
         )
 
     @defer.inlineCallbacks
     def test_avatar_url(self):
-        yield self.store.create_profile(self.u_frank.localpart)
+        yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
 
-        yield self.store.set_profile_avatar_url(
-            self.u_frank.localpart, "http://my.site/here"
+        yield defer.ensureDeferred(
+            self.store.set_profile_avatar_url(
+                self.u_frank.localpart, "http://my.site/here"
+            )
         )
 
         self.assertEquals(
             "http://my.site/here",
-            (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.u_frank.localpart)
+                )
+            ),
         )
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..918387733b 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
+from synapse.api.errors import NotFoundError
 from synapse.rest.client.v1 import room
 
 from tests.unittest import HomeserverTestCase
@@ -44,28 +47,19 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_storage()
 
         # Get the topological token
-        event = store.get_topological_token_for_event(last["event_id"])
-        self.pump()
-        event = self.successResultOf(event)
+        event = self.get_success(
+            store.get_topological_token_for_event(last["event_id"])
+        )
 
         # Purge everything before this topological token
-        purge = storage.purge_events.purge_history(self.room_id, event, True)
-        self.pump()
-        self.assertEqual(self.successResultOf(purge), None)
-
-        # Try and get the events
-        get_first = store.get_event(first["event_id"])
-        get_second = store.get_event(second["event_id"])
-        get_third = store.get_event(third["event_id"])
-        get_last = store.get_event(last["event_id"])
-        self.pump()
+        self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
         # and last is not.
-        self.failureResultOf(get_first)
-        self.failureResultOf(get_second)
-        self.failureResultOf(get_third)
-        self.successResultOf(get_last)
+        self.get_failure(store.get_event(first["event_id"]), NotFoundError)
+        self.get_failure(store.get_event(second["event_id"]), NotFoundError)
+        self.get_failure(store.get_event(third["event_id"]), NotFoundError)
+        self.get_success(store.get_event(last["event_id"]))
 
     def test_purge_wont_delete_extrems(self):
         """
@@ -80,28 +74,21 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_datastore()
 
         # Set the topological token higher than it should be
-        event = storage.get_topological_token_for_event(last["event_id"])
-        self.pump()
-        event = self.successResultOf(event)
+        event = self.get_success(
+            storage.get_topological_token_for_event(last["event_id"])
+        )
         event = "t{}-{}".format(
             *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
         )
 
         # Purge everything before this topological token
-        purge = storage.purge_history(self.room_id, event, True)
+        purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
         self.pump()
         f = self.failureResultOf(purge)
         self.assertIn("greater than forward", f.value.args[0])
 
         # Try and get the events
-        get_first = storage.get_event(first["event_id"])
-        get_second = storage.get_event(second["event_id"])
-        get_third = storage.get_event(third["event_id"])
-        get_last = storage.get_event(last["event_id"])
-        self.pump()
-
-        # Nothing is deleted.
-        self.successResultOf(get_first)
-        self.successResultOf(get_second)
-        self.successResultOf(get_third)
-        self.successResultOf(get_last)
+        self.get_success(storage.get_event(first["event_id"]))
+        self.get_success(storage.get_event(second["event_id"]))
+        self.get_success(storage.get_event(third["event_id"]))
+        self.get_success(storage.get_event(last["event_id"]))
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index db3667dc43..1ea35d60c1 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
             @defer.inlineCallbacks
             def build(self, prev_event_ids):
-                built_event = yield self._base_builder.build(prev_event_ids)
+                built_event = yield defer.ensureDeferred(
+                    self._base_builder.build(prev_event_ids)
+                )
 
                 built_event._event_id = self._event_id
                 built_event._dict["event_id"] = self._event_id
@@ -249,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             def room_id(self):
                 return self._base_builder.room_id
 
+            @property
+            def type(self):
+                return self._base_builder.type
+
         event_1, context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 EventIdManglingBuilder(
@@ -341,7 +347,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         )
 
         event_json = self.get_success(
-            self.store.db.simple_select_one_onecol(
+            self.store.db_pool.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
@@ -359,7 +365,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(60 * 60 * 2)
 
         event_json = self.get_success(
-            self.store.db.simple_select_one_onecol(
+            self.store.db_pool.simple_select_one_onecol(
                 table="event_json",
                 keyvalues={"event_id": msg_event.event_id},
                 retcol="json",
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 71a40a0a49..6b582771fe 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,6 +17,7 @@
 from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
+from synapse.api.errors import ThreepidValidationError
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -36,7 +37,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_register(self):
-        yield self.store.register_user(self.user_id, self.pwhash)
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
 
         self.assertEquals(
             {
@@ -52,17 +53,21 @@ class RegistrationStoreTestCase(unittest.TestCase):
                 "user_type": None,
                 "deactivated": 0,
             },
-            (yield self.store.get_user_by_id(self.user_id)),
+            (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
         )
 
     @defer.inlineCallbacks
     def test_add_tokens(self):
-        yield self.store.register_user(self.user_id, self.pwhash)
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+            )
         )
 
-        result = yield self.store.get_user_by_access_token(self.tokens[1])
+        result = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[1])
+        )
 
         self.assertDictContainsSubset(
             {"name": self.user_id, "device_id": self.device_id}, result
@@ -73,31 +78,41 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
         # add some tokens
-        yield self.store.register_user(self.user_id, self.pwhash)
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+        yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash))
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+            )
         )
-        yield self.store.add_access_token_to_user(
-            self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+        yield defer.ensureDeferred(
+            self.store.add_access_token_to_user(
+                self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+            )
         )
 
         # now delete some
-        yield self.store.user_delete_access_tokens(
-            self.user_id, device_id=self.device_id
+        yield defer.ensureDeferred(
+            self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id)
         )
 
         # check they were deleted
-        user = yield self.store.get_user_by_access_token(self.tokens[1])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[1])
+        )
         self.assertIsNone(user, "access token was not deleted by device_id")
 
         # check the one not associated with the device was not deleted
-        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[0])
+        )
         self.assertEqual(self.user_id, user["name"])
 
         # now delete the rest
-        yield self.store.user_delete_access_tokens(self.user_id)
+        yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
 
-        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        user = yield defer.ensureDeferred(
+            self.store.get_user_by_access_token(self.tokens[0])
+        )
         self.assertIsNone(user, "access token was not deleted without device_id")
 
     @defer.inlineCallbacks
@@ -105,14 +120,48 @@ class RegistrationStoreTestCase(unittest.TestCase):
         TEST_USER = "@test:test"
         SUPPORT_USER = "@support:test"
 
-        res = yield self.store.is_support_user(None)
+        res = yield defer.ensureDeferred(self.store.is_support_user(None))
         self.assertFalse(res)
-        yield self.store.register_user(user_id=TEST_USER, password_hash=None)
-        res = yield self.store.is_support_user(TEST_USER)
+        yield defer.ensureDeferred(
+            self.store.register_user(user_id=TEST_USER, password_hash=None)
+        )
+        res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER))
         self.assertFalse(res)
 
-        yield self.store.register_user(
-            user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+        yield defer.ensureDeferred(
+            self.store.register_user(
+                user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT
+            )
         )
-        res = yield self.store.is_support_user(SUPPORT_USER)
+        res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER))
         self.assertTrue(res)
+
+    @defer.inlineCallbacks
+    def test_3pid_inhibit_invalid_validation_session_error(self):
+        """Tests that enabling the configuration option to inhibit 3PID errors on
+        /requestToken also inhibits validation errors caused by an unknown session ID.
+        """
+
+        # Check that, with the config setting set to false (the default value), a
+        # validation error is caused by the unknown session ID.
+        try:
+            yield defer.ensureDeferred(
+                self.store.validate_threepid_session(
+                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                )
+            )
+        except ThreepidValidationError as e:
+            self.assertEquals(e.msg, "Unknown session_id", e)
+
+        # Set the config setting to true.
+        self.store._ignore_unknown_session_error = True
+
+        # Check that now the validation error is caused by the token not matching.
+        try:
+            yield defer.ensureDeferred(
+                self.store.validate_threepid_session(
+                    "fake_sid", "fake_client_secret", "fake_token", 0,
+                )
+            )
+        except ThreepidValidationError as e:
+            self.assertEquals(e.msg, "Validation token not found or has expired", e)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 3b78d48896..bc8400f240 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase):
         self.alias = RoomAlias.from_string("#a-room-name:test")
         self.u_creator = UserID.from_string("@creator:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id=self.u_creator.to_string(),
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id=self.u_creator.to_string(),
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
@@ -52,7 +54,13 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "is_public": True,
             },
-            (yield self.store.get_room(self.room.to_string())),
+            (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
+        )
+
+    @defer.inlineCallbacks
+    def test_get_room_unknown_room(self):
+        self.assertIsNone(
+            (yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
         )
 
     @defer.inlineCallbacks
@@ -63,7 +71,21 @@ class RoomStoreTestCase(unittest.TestCase):
                 "creator": self.u_creator.to_string(),
                 "public": True,
             },
-            (yield self.store.get_room_with_stats(self.room.to_string())),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_room_with_stats(self.room.to_string())
+                )
+            ),
+        )
+
+    @defer.inlineCallbacks
+    def test_get_room_with_stats_unknown_room(self):
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_room_with_stats("!uknown:test")
+                )
+            ),
         )
 
 
@@ -80,17 +102,21 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
         self.room = RoomID.from_string("!abcde:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.storage.persistence.persist_event(
-            self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(
+                self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+            )
         )
 
     @defer.inlineCallbacks
@@ -101,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
@@ -117,7 +145,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005e6..12ccc1f53e 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -87,7 +87,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
         self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
 
-        self.pump(20)
+        self.pump()
 
         self.assertTrue("_known_servers_count" not in self.store.__dict__.keys())
 
@@ -101,7 +101,7 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Initialises to 1 -- itself
         self.assertEqual(self.store._known_servers_count, 1)
 
-        self.pump(20)
+        self.pump()
 
         # No rooms have been joined, so technically the SQL returns 0, but it
         # will still say it knows about itself.
@@ -111,25 +111,29 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
         self.inject_room_member(self.room, self.u_charlie.to_string(), Membership.JOIN)
 
-        self.pump(20)
+        self.pump(1)
 
         # It now knows about Charlie's server.
         self.assertEqual(self.store._known_servers_count, 2)
 
     def test_get_joined_users_from_context(self):
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
-        bob_event = event_injection.inject_member_event(
-            self.hs, room, self.u_bob, Membership.JOIN
+        bob_event = self.get_success(
+            event_injection.inject_member_event(
+                self.hs, room, self.u_bob, Membership.JOIN
+            )
         )
 
         # first, create a regular event
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.1",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.1",
+                content={},
+            )
         )
 
         users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Regression test for #7376: create a state event whose key matches bob's
         # user_id, but which is *not* a membership event, and persist that; then check
         # that `get_joined_users_from_context` returns the correct users for the next event.
-        non_member_event = event_injection.inject_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_bob,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.2",
-            state_key=self.u_bob,
-            content={},
+        non_member_event = self.get_success(
+            event_injection.inject_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_bob,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.2",
+                state_key=self.u_bob,
+                content={},
+            )
         )
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[non_member_event.event_id],
-            type="m.test.3",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[non_member_event.event_id],
+                type="m.test.3",
+                content={},
+            )
         )
         users = self.get_success(
             self.store.get_joined_users_from_context(event, context)
@@ -171,20 +179,20 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
     def test_can_rerun_update(self):
         # First make sure we have completed all updates.
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Now let's create a room, which will insert a membership
         user = UserID("alice", "test")
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
         self.get_success(self.room_creator.create_room(requester, {}))
 
         # Register the background update to run again.
         self.get_success(
-            self.store.db.simple_insert(
+            self.store.db_pool.simple_insert(
                 table="background_updates",
                 values={
                     "update_name": "current_state_events_membership",
@@ -195,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
         )
 
         # ... and tell the DataStore that it hasn't finished all updates yet
-        self.store.db.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         # Now let's actually drive the updates to completion
         while not self.get_success(
-            self.store.db.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                self.store.db.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308ff4..8bd12fa847 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.room = RoomID.from_string("!abc123:test")
 
-        yield self.store.store_room(
-            self.room.to_string(),
-            room_creator_user_id="@creator:text",
-            is_public=True,
-            room_version=RoomVersions.V1,
+        yield defer.ensureDeferred(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
         )
 
     @defer.inlineCallbacks
@@ -64,11 +66,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
 
         return event
 
@@ -87,8 +91,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups_ids(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
@@ -106,8 +110,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
@@ -148,7 +152,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.storage.state.get_state_for_event(e5.event_id)
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(e5.event_id)
+        )
 
         self.assertIsNotNone(e4)
 
@@ -164,22 +170,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+            )
         )
 
         self.assertStateMapEqual(
@@ -188,12 +200,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: {self.u_alice.to_string()}},
-                include_others=True,
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    include_others=True,
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -206,11 +220,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: set()}, include_others=True
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -222,8 +238,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.storage.state.get_state_groups_ids(
-            room_id, [e5.event_id]
+        group_ids = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..738e912468 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -31,16 +31,24 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
 
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
-        yield self.store.update_profile_in_user_dir(ALICE, "alice", None)
-        yield self.store.update_profile_in_user_dir(BOB, "bob", None)
-        yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
-        yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(ALICE, "alice", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BOB, "bob", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.update_profile_in_user_dir(BOBBY, "bobby", None)
+        )
+        yield defer.ensureDeferred(
+            self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))
+        )
 
     @defer.inlineCallbacks
     def test_search_user_dir(self):
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
-        r = yield self.store.search_user_dir(ALICE, "bob", 10)
+        r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
         self.assertFalse(r["limited"])
         self.assertEqual(1, len(r["results"]))
         self.assertDictEqual(
@@ -51,7 +59,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
     def test_search_user_dir_all_users(self):
         self.hs.config.user_directory_search_all_users = True
         try:
-            r = yield self.store.search_user_dir(ALICE, "bob", 10)
+            r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
             self.assertFalse(r["limited"])
             self.assertEqual(2, len(r["results"]))
             self.assertDictEqual(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c662195eec..27a7fc9ed7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,7 +1,23 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from mock import Mock
 
-from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
+from twisted.internet.defer import succeed
 
+from synapse.api.errors import FederationError
 from synapse.events import make_event_from_dict
 from synapse.logging.context import LoggingContext
 from synapse.types import Requester, UserID
@@ -10,6 +26,7 @@ from synapse.util.retryutils import NotRetryingDestination
 
 from tests import unittest
 from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
+from tests.test_utils import make_awaitable
 
 
 class MessageAcceptTests(unittest.HomeserverTestCase):
@@ -26,24 +43,19 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         user_id = UserID("us", "test")
-        our_user = Requester(user_id, None, False, None, None)
+        our_user = Requester(user_id, None, False, False, None, None)
         room_creator = self.homeserver.get_room_creation_handler()
-        room_deferred = ensureDeferred(
+        self.room_id = self.get_success(
             room_creator.create_room(
-                our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
+                our_user, room_creator._presets_dict["public_chat"], ratelimit=False
             )
-        )
-        self.reactor.advance(0.1)
-        self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
+        )[0]["room_id"]
 
         self.store = self.homeserver.get_datastore()
 
         # Figure out what the most recent event is
-        most_recent = self.successResultOf(
-            maybeDeferred(
-                self.homeserver.get_datastore().get_latest_event_ids_in_room,
-                self.room_id,
-            )
+        most_recent = self.get_success(
+            self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         join_event = make_event_from_dict(
@@ -73,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Send the join, it should return None (which is not an error)
-        d = ensureDeferred(
-            self.handler.on_receive_pdu(
-                "test.serv", join_event, sent_to_us_directly=True
-            )
+        self.assertEqual(
+            self.get_success(
+                self.handler.on_receive_pdu(
+                    "test.serv", join_event, sent_to_us_directly=True
+                )
+            ),
+            None,
         )
-        self.reactor.advance(1)
-        self.assertEqual(self.successResultOf(d), None)
 
         # Make sure we actually joined the room
         self.assertEqual(
-            self.successResultOf(
-                maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
-            )[0],
+            self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
             "$join:test.serv",
         )
 
@@ -95,7 +106,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         prev_events that said event references.
         """
 
-        def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(destination, path, data, headers=None, timeout=0):
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}
@@ -103,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.http_client.post_json = post_json
 
         # Figure out what the most recent event is
-        most_recent = self.successResultOf(
-            maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
+        most_recent = self.get_success(
+            self.store.get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         # Now lie about an event
@@ -124,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         with LoggingContext(request="lying_event"):
-            d = ensureDeferred(
+            failure = self.get_failure(
                 self.handler.on_receive_pdu(
                     "test.serv", lying_event, sent_to_us_directly=True
-                )
+                ),
+                FederationError,
             )
 
-            # Step the reactor, so the database fetches come back
-            self.reactor.advance(1)
-
         # on_receive_pdu should throw an error
-        failure = self.failureResultOf(d)
         self.assertEqual(
             failure.value.args[0],
             (
@@ -144,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Make sure the invalid event isn't there
-        extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
-        self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
+        extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
+        self.assertEqual(extrem[0], "$join:test.serv")
 
     def test_retry_device_list_resync(self):
         """Tests that device lists are marked as stale if they couldn't be synced, and
@@ -173,7 +181,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
         store = self.homeserver.get_datastore()
-        store.get_rooms_for_user = Mock(return_value=["!someroom:test"])
+        store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
@@ -218,23 +226,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # Register mock device list retrieval on the federation client.
         federation_client = self.homeserver.get_federation_client()
         federation_client.query_user_devices = Mock(
-            return_value={
-                "user_id": remote_user_id,
-                "stream_id": 1,
-                "devices": [],
-                "master_key": {
-                    "user_id": remote_user_id,
-                    "usage": ["master"],
-                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
-                },
-                "self_signing_key": {
+            return_value=succeed(
+                {
                     "user_id": remote_user_id,
-                    "usage": ["self_signing"],
-                    "keys": {
-                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
+                    "stream_id": 1,
+                    "devices": [],
+                    "master_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
                     },
-                },
-            }
+                    "self_signing_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    },
+                }
+            )
         )
 
         # Resync the device list.
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 49667ed7f4..654a6fa42d 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -166,7 +166,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.do_sync_for_user(token5)
         self.do_sync_for_user(token6)
 
-        # But old user cant
+        # But old user can't
         with self.assertRaises(SynapseError) as cm:
             self.do_sync_for_user(token1)
 
diff --git a/tests/test_server.py b/tests/test_server.py
index e9a43b1e45..655c918a15 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -12,31 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import re
 
-from six import StringIO
-
 from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
 
 from synapse.api.errors import Codes, RedirectException, SynapseError
-from synapse.http.server import (
-    DirectServeResource,
-    JsonResource,
-    OptionsResource,
-    wrap_html_request_handler,
-)
-from synapse.http.site import SynapseSite, logger
+from synapse.config.server import parse_listener_def
+from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
+from synapse.http.site import SynapseSite
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
 
 from tests import unittest
 from tests.server import (
-    FakeTransport,
     ThreadedMemoryReactorClock,
     make_request,
     render,
@@ -168,6 +157,28 @@ class JsonResourceTests(unittest.TestCase):
         self.assertEqual(channel.json_body["error"], "Unrecognized request")
         self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
 
+    def test_head_request(self):
+        """
+        JsonResource.handler_for_request gives correctly decoded URL args to
+        the callback, while Twisted will give the raw bytes of URL query
+        arguments.
+        """
+
+        def _callback(request, **kwargs):
+            return 200, {"result": True}
+
+        res = JsonResource(self.homeserver)
+        res.register_paths(
+            "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+        )
+
+        # The path was registered as GET, but this is a HEAD request.
+        request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertNotIn("body", channel.result)
+
 
 class OptionsResourceTests(unittest.TestCase):
     def setUp(self):
@@ -189,7 +200,13 @@ class OptionsResourceTests(unittest.TestCase):
         request.prepath = []  # This doesn't get set properly by make_request.
 
         # Create a site and query for the resource.
-        site = SynapseSite("test", "site_tag", {}, self.resource, "1.0")
+        site = SynapseSite(
+            "test",
+            "site_tag",
+            parse_listener_def({"type": "http", "port": 0}),
+            self.resource,
+            "1.0",
+        )
         request.site = site
         resource = site.getResourceFor(request)
 
@@ -198,10 +215,10 @@ class OptionsResourceTests(unittest.TestCase):
         return channel
 
     def test_unknown_options_request(self):
-        """An OPTIONS requests to an unknown URL still returns 200 OK."""
+        """An OPTIONS requests to an unknown URL still returns 204 No Content."""
         channel = self._make_request(b"OPTIONS", b"/foo/")
-        self.assertEqual(channel.result["code"], b"200")
-        self.assertEqual(channel.result["body"], b"{}")
+        self.assertEqual(channel.result["code"], b"204")
+        self.assertNotIn("body", channel.result)
 
         # Ensure the correct CORS headers have been added
         self.assertTrue(
@@ -218,10 +235,10 @@ class OptionsResourceTests(unittest.TestCase):
         )
 
     def test_known_options_request(self):
-        """An OPTIONS requests to an known URL still returns 200 OK."""
+        """An OPTIONS requests to an known URL still returns 204 No Content."""
         channel = self._make_request(b"OPTIONS", b"/res/")
-        self.assertEqual(channel.result["code"], b"200")
-        self.assertEqual(channel.result["body"], b"{}")
+        self.assertEqual(channel.result["code"], b"204")
+        self.assertNotIn("body", channel.result)
 
         # Ensure the correct CORS headers have been added
         self.assertTrue(
@@ -250,18 +267,17 @@ class OptionsResourceTests(unittest.TestCase):
 
 
 class WrapHtmlRequestHandlerTests(unittest.TestCase):
-    class TestResource(DirectServeResource):
+    class TestResource(DirectServeHtmlResource):
         callback = None
 
-        @wrap_html_request_handler
         async def _async_render_GET(self, request):
-            return await self.callback(request)
+            await self.callback(request)
 
     def setUp(self):
         self.reactor = ThreadedMemoryReactorClock()
 
     def test_good_response(self):
-        def callback(request):
+        async def callback(request):
             request.write(b"response")
             request.finish()
 
@@ -281,7 +297,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         with the right location.
         """
 
-        def callback(request, **kwargs):
+        async def callback(request, **kwargs):
             raise RedirectException(b"/look/an/eagle", 301)
 
         res = WrapHtmlRequestHandlerTests.TestResource()
@@ -301,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         returned too
         """
 
-        def callback(request, **kwargs):
+        async def callback(request, **kwargs):
             e = RedirectException(b"/no/over/there", 304)
             e.cookies.append(b"session=yespls")
             raise e
@@ -319,51 +335,18 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
         self.assertEqual(cookies_headers, [b"session=yespls"])
 
+    def test_head_request(self):
+        """A head request should work by being turned into a GET request."""
 
-class SiteTestCase(unittest.HomeserverTestCase):
-    def test_lose_connection(self):
-        """
-        We log the URI correctly redacted when we lose the connection.
-        """
+        async def callback(request):
+            request.write(b"response")
+            request.finish()
 
-        class HangingResource(Resource):
-            """
-            A Resource that strategically hangs, as if it were processing an
-            answer.
-            """
+        res = WrapHtmlRequestHandlerTests.TestResource()
+        res.callback = callback
 
-            def render(self, request):
-                return NOT_DONE_YET
-
-        # Set up a logging handler that we can inspect afterwards
-        output = StringIO()
-        handler = logging.StreamHandler(output)
-        logger.addHandler(handler)
-        old_level = logger.level
-        logger.setLevel(10)
-        self.addCleanup(logger.setLevel, old_level)
-        self.addCleanup(logger.removeHandler, handler)
-
-        # Make a resource and a Site, the resource will hang and allow us to
-        # time out the request while it's 'processing'
-        base_resource = Resource()
-        base_resource.putChild(b"", HangingResource())
-        site = SynapseSite("test", "site_tag", {}, base_resource, "1.0")
-
-        server = site.buildProtocol(None)
-        client = AccumulatingProtocol()
-        client.makeConnection(FakeTransport(server, self.reactor))
-        server.makeConnection(FakeTransport(client, self.reactor))
-
-        # Send a request with an access token that will get redacted
-        server.dataReceived(b"GET /?access_token=bar HTTP/1.0\r\n\r\n")
-        self.pump()
-
-        # Lose the connection
-        e = Failure(Exception("Failed123"))
-        server.connectionLost(e)
-        handler.flush()
-
-        # Our access token is redacted and the failure reason is logged.
-        self.assertIn("/?access_token=<redacted>", output.getvalue())
-        self.assertIn("Failed123", output.getvalue())
+        request, channel = make_request(self.reactor, b"HEAD", b"/path")
+        render(request, res, self.reactor)
+
+        self.assertEqual(channel.result["code"], b"200")
+        self.assertNotIn("body", channel.result)
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..2d58467932 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -71,7 +71,7 @@ def create_event(
     return event
 
 
-class StateGroupStore(object):
+class StateGroupStore:
     def __init__(self):
         self._event_to_state_group = {}
         self._group_to_state = {}
@@ -80,16 +80,16 @@ class StateGroupStore(object):
 
         self._next_group = 1
 
-    def get_state_groups_ids(self, room_id, event_ids):
+    async def get_state_groups_ids(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
             if group:
                 groups[group] = self._group_to_state[group]
 
-        return defer.succeed(groups)
+        return groups
 
-    def store_state_group(
+    async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
     ):
         state_group = self._next_group
@@ -99,15 +99,15 @@ class StateGroupStore(object):
 
         return state_group
 
-    def get_events(self, event_ids, **kwargs):
+    async def get_events(self, event_ids, **kwargs):
         return {
             e_id: self._event_id_to_event[e_id]
             for e_id in event_ids
             if e_id in self._event_id_to_event
         }
 
-    def get_state_group_delta(self, name):
-        return None, None
+    async def get_state_group_delta(self, name):
+        return (None, None)
 
     def register_events(self, events):
         for e in events:
@@ -119,7 +119,7 @@ class StateGroupStore(object):
     def register_event_id_state_group(self, event_id, state_group):
         self._event_to_state_group[event_id] = state_group
 
-    def get_room_version_id(self, room_id):
+    async def get_room_version_id(self, room_id):
         return RoomVersions.V1.identifier
 
 
@@ -129,7 +129,7 @@ class DictObj(dict):
         self.__dict__ = self
 
 
-class Graph(object):
+class Graph:
     def __init__(self, nodes, edges):
         events = {}
         clobbered = set(events.keys())
@@ -202,14 +202,16 @@ class StateTestCase(unittest.TestCase):
         context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertEqual(2, len(prev_state_ids))
 
         self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -244,7 +246,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -253,7 +257,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
 
         self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -300,7 +304,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -310,7 +316,7 @@ class StateTestCase(unittest.TestCase):
         ctx_c = context_store["C"]
         ctx_e = context_store["E"]
 
-        prev_state_ids = yield ctx_e.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
         self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
         self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
         self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -373,7 +379,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -383,7 +391,7 @@ class StateTestCase(unittest.TestCase):
         ctx_b = context_store["B"]
         ctx_d = context_store["D"]
 
-        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
         self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
 
         self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -411,12 +419,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
         )
@@ -434,12 +444,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
@@ -462,18 +474,20 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
-            prev_event_id,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state},
+        group_name = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(
             {e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,18 +508,20 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
-            prev_event_id,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state},
+        group_name = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
 
         self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
 
@@ -544,7 +560,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -586,7 +602,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -641,7 +657,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
@@ -669,29 +685,35 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
 
+    @defer.inlineCallbacks
     def _get_context(
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
-        sg1 = self.store.store_state_group(
-            prev_event_id_1,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state_1},
+        sg1 = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id_1,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state_1},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        sg2 = self.store.store_state_group(
-            prev_event_id_2,
-            event.room_id,
-            None,
-            None,
-            {(e.type, e.state_key): e.event_id for e in old_state_2},
+        sg2 = yield defer.ensureDeferred(
+            self.store.store_state_group(
+                prev_event_id_2,
+                event.room_id,
+                None,
+                None,
+                {(e.type, e.state_key): e.event_id for e in old_state_2},
+            )
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
-        return self.state.compute_event_context(event)
+        result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+        return result
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 5c2817cf28..b89798336c 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -14,7 +14,6 @@
 
 import json
 
-import six
 from mock import Mock
 
 from twisted.test.proto_helpers import MemoryReactorClock
@@ -60,7 +59,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"401", channel.result)
 
         self.assertTrue(channel.json_body is not None)
-        self.assertIsInstance(channel.json_body["session"], six.text_type)
+        self.assertIsInstance(channel.json_body["session"], str)
 
         self.assertIsInstance(channel.json_body["flows"], list)
         for flow in channel.json_body["flows"]:
@@ -125,6 +124,6 @@ class TermsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
         self.assertTrue(channel.json_body is not None)
-        self.assertIsInstance(channel.json_body["user_id"], six.text_type)
-        self.assertIsInstance(channel.json_body["access_token"], six.text_type)
-        self.assertIsInstance(channel.json_body["device_id"], six.text_type)
+        self.assertIsInstance(channel.json_body["user_id"], str)
+        self.assertIsInstance(channel.json_body["access_token"], str)
+        self.assertIsInstance(channel.json_body["device_id"], str)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
 """
 Utilities for running the unit tests
 """
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
 
     # if next didn't raise, the awaitable hasn't completed.
     raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+    """Create an awaitable that just returns a result."""
+    return result
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 431e9f8e5e..fb1ca90336 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,25 +13,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import synapse.server
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.types import Collection
-
-from tests.test_utils import get_awaitable_result
-
 
 """
 Utility functions for poking events into the storage of the server under test.
 """
 
 
-def inject_member_event(
+async def inject_member_event(
     hs: synapse.server.HomeServer,
     room_id: str,
     sender: str,
@@ -48,7 +43,7 @@ def inject_member_event(
     if extra_content:
         content.update(extra_content)
 
-    return inject_event(
+    return await inject_event(
         hs,
         room_id=room_id,
         type=EventTypes.Member,
@@ -59,10 +54,10 @@ def inject_member_event(
     )
 
 
-def inject_event(
+async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
-    prev_event_ids: Optional[Collection[str]] = None,
+    prev_event_ids: Optional[List[str]] = None,
     **kwargs
 ) -> EventBase:
     """Inject a generic event into a room
@@ -74,37 +69,27 @@ def inject_event(
         prev_event_ids: prev_events for the event. If not specified, will be looked up
         kwargs: fields for the event to be created
     """
-    test_reactor = hs.get_reactor()
+    event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
 
-    event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
-
-    d = hs.get_storage().persistence.persist_event(event, context)
-    test_reactor.advance(0)
-    get_awaitable_result(d)
+    await hs.get_storage().persistence.persist_event(event, context)
 
     return event
 
 
-def create_event(
+async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
-    prev_event_ids: Optional[Collection[str]] = None,
+    prev_event_ids: Optional[List[str]] = None,
     **kwargs
 ) -> Tuple[EventBase, EventContext]:
-    test_reactor = hs.get_reactor()
-
     if room_version is None:
-        d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
-        test_reactor.advance(0)
-        room_version = get_awaitable_result(d)
+        room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
 
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
-    d = hs.get_event_creation_handler().create_new_client_event(
+    event, context = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
-    test_reactor.advance(0)
-    event, context = get_awaitable_result(d)
 
     return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f7381b2885..510b630114 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -37,10 +37,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.hs = yield setup_test_homeserver(self.addCleanup)
         self.event_creation_handler = self.hs.get_event_creation_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
-        self.store = self.hs.get_datastore()
         self.storage = self.hs.get_storage()
 
-        yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+        yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
 
     @defer.inlineCallbacks
     def test_filtering(self):
@@ -53,7 +52,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self.inject_visibility("@admin:hs", "joined")
+        yield self.inject_visibility("@admin:hs", "joined")
         for i in range(0, 10):
             yield self.inject_room_member("@resident%i:hs" % i)
 
@@ -64,8 +63,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             evt = yield self.inject_room_member(user, extra_content={"a": "b"})
             events_to_filter.append(evt)
 
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -99,11 +98,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         events_to_filter.append(evt)
 
         # the erasey user gets erased
-        yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+        yield defer.ensureDeferred(
+            self.hs.get_datastore().mark_user_erased("@erased:local_hs")
+        )
 
         # ... and the filtering happens.
-        filtered = yield filter_events_for_server(
-            self.storage, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(self.storage, "test_server", events_to_filter)
         )
 
         for i in range(0, len(events_to_filter)):
@@ -137,10 +138,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
         )
-        yield self.storage.persistence.persist_event(event, context)
         return event
 
     @defer.inlineCallbacks
@@ -158,11 +161,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -179,11 +184,13 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -265,8 +272,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         storage.main = test_store
         storage.state = test_store
 
-        filtered = yield filter_events_for_server(
-            test_store, "test_server", events_to_filter
+        filtered = yield defer.ensureDeferred(
+            filter_events_for_server(test_store, "test_server", events_to_filter)
         )
         logger.info("Filtering took %f seconds", time.time() - start)
 
@@ -287,7 +294,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
     test_large_room.skip = "Disabled by default because it's slow"
 
 
-class _TestStore(object):
+class _TestStore:
     """Implements a few methods of the DataStore, so that we can test
     filter_events_for_server
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 6b6f224e9c..3cb55a7e96 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -229,7 +229,7 @@ class HomeserverTestCase(TestCase):
         self.site = SynapseSite(
             logger_name="synapse.access.http.fake",
             site_tag="test",
-            config={},
+            config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
         )
@@ -241,20 +241,20 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "user_id"):
             if self.hijack_auth:
 
-                def get_user_by_access_token(token=None, allow_guest=False):
-                    return succeed(
-                        {
-                            "user": UserID.from_string(self.helper.auth_user_id),
-                            "token_id": 1,
-                            "is_guest": False,
-                        }
-                    )
-
-                def get_user_by_req(request, allow_guest=False, rights="access"):
-                    return succeed(
-                        create_requester(
-                            UserID.from_string(self.helper.auth_user_id), 1, False, None
-                        )
+                async def get_user_by_access_token(token=None, allow_guest=False):
+                    return {
+                        "user": UserID.from_string(self.helper.auth_user_id),
+                        "token_id": 1,
+                        "is_guest": False,
+                    }
+
+                async def get_user_by_req(request, allow_guest=False, rights="access"):
+                    return create_requester(
+                        UserID.from_string(self.helper.auth_user_id),
+                        1,
+                        False,
+                        False,
+                        None,
                     )
 
                 self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -422,8 +422,8 @@ class HomeserverTestCase(TestCase):
 
         async def run_bg_updates():
             with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
-                while not await stor.db.updates.has_completed_background_updates():
-                    await stor.db.updates.do_next_background_update(1)
+                while not await stor.db_pool.updates.has_completed_background_updates():
+                    await stor.db_pool.updates.do_next_background_update(1)
 
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()
@@ -544,7 +544,7 @@ class HomeserverTestCase(TestCase):
         """
         event_creator = self.hs.get_event_creation_handler()
         secrets = self.hs.get_secrets()
-        requester = Requester(user, None, False, None, None)
+        requester = Requester(user, None, False, False, None, None)
 
         event, context = self.get_success(
             event_creator.create_event(
@@ -571,7 +571,7 @@ class HomeserverTestCase(TestCase):
         Add the given event as an extremity to the room.
         """
         self.get_success(
-            self.hs.get_datastore().db.simple_insert(
+            self.hs.get_datastore().db_pool.simple_insert(
                 table="event_forward_extremities",
                 values={"room_id": room_id, "event_id": event_id},
                 desc="test_add_extremity",
@@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
             user: MXID of the user to inject the membership for.
             membership: The membership type.
         """
-        event_injection.inject_member_event(self.hs, room, user, membership)
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, user, membership)
+        )
 
 
 class FederatingHomeserverTestCase(HomeserverTestCase):
@@ -612,7 +614,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     """
 
     def prepare(self, reactor, clock, homeserver):
-        class Authenticator(object):
+        class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")
 
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 4d2b9e0d64..677e925477 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -88,7 +88,7 @@ class CacheTestCase(unittest.TestCase):
 class DescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_cache(self):
-        class Cls(object):
+        class Cls:
             def __init__(self):
                 self.mock = mock.Mock()
 
@@ -122,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase):
     def test_cache_num_args(self):
         """Only the first num_args arguments should matter to the cache"""
 
-        class Cls(object):
+        class Cls:
             def __init__(self):
                 self.mock = mock.Mock()
 
@@ -156,7 +156,7 @@ class DescriptorTestCase(unittest.TestCase):
         """If the wrapped function throws synchronously, things should continue to work
         """
 
-        class Cls(object):
+        class Cls:
             @cached()
             def fn(self, arg1):
                 raise SynapseError(100, "mai spoon iz too big!!1")
@@ -180,7 +180,7 @@ class DescriptorTestCase(unittest.TestCase):
 
         complete_lookup = defer.Deferred()
 
-        class Cls(object):
+        class Cls:
             @descriptors.cached()
             def fn(self, arg1):
                 @defer.inlineCallbacks
@@ -223,7 +223,7 @@ class DescriptorTestCase(unittest.TestCase):
         """Check that the cache sets and restores logcontexts correctly when
         the lookup function throws an exception"""
 
-        class Cls(object):
+        class Cls:
             @descriptors.cached()
             def fn(self, arg1):
                 @defer.inlineCallbacks
@@ -263,7 +263,7 @@ class DescriptorTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_cache_default_args(self):
-        class Cls(object):
+        class Cls:
             def __init__(self):
                 self.mock = mock.Mock()
 
@@ -300,7 +300,7 @@ class DescriptorTestCase(unittest.TestCase):
         obj.mock.assert_not_called()
 
     def test_cache_iterable(self):
-        class Cls(object):
+        class Cls:
             def __init__(self):
                 self.mock = mock.Mock()
 
@@ -336,7 +336,7 @@ class DescriptorTestCase(unittest.TestCase):
         """If the wrapped function throws synchronously, things should continue to work
         """
 
-        class Cls(object):
+        class Cls:
             @descriptors.cached(iterable=True)
             def fn(self, arg1):
                 raise SynapseError(100, "mai spoon iz too big!!1")
@@ -358,7 +358,7 @@ class DescriptorTestCase(unittest.TestCase):
 class CachedListDescriptorTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_cache(self):
-        class Cls(object):
+        class Cls:
             def __init__(self):
                 self.mock = mock.Mock()
 
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
-            def list_fn(self, args1, arg2):
+            @descriptors.cachedList("fn", "args1")
+            async def list_fn(self, args1, arg2):
                 assert current_context().request == "c1"
                 # we want this to behave like an asynchronous function
-                yield run_on_reactor()
+                await run_on_reactor()
                 assert current_context().request == "c1"
                 return self.mock(args1, arg2)
 
@@ -408,7 +408,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
     def test_invalidate(self):
         """Make sure that invalidation callbacks are called."""
 
-        class Cls(object):
+        class Cls:
             def __init__(self):
                 self.mock = mock.Mock()
 
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
             def fn(self, arg1, arg2):
                 pass
 
-            @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
-            def list_fn(self, args1, arg2):
+            @descriptors.cachedList("fn", "args1")
+            async def list_fn(self, args1, arg2):
                 # we want this to behave like an asynchronous function
-                yield run_on_reactor()
+                await run_on_reactor()
                 return self.mock(args1, arg2)
 
         obj = Cls()
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index e90e08d1c0..2012263184 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -15,9 +15,9 @@
 
 
 import threading
+from io import StringIO
 
 from mock import NonCallableMock
-from six import StringIO
 
 from twisted.internet import defer, reactor
 
@@ -112,7 +112,7 @@ class FileConsumerTests(unittest.TestCase):
         self.assertTrue(string_file.closed)
 
 
-class DummyPullProducer(object):
+class DummyPullProducer:
     def __init__(self):
         self.consumer = None
         self.deferred = defer.Deferred()
@@ -134,7 +134,7 @@ class DummyPullProducer(object):
         return d
 
 
-class BlockingStringWrite(object):
+class BlockingStringWrite:
     def __init__(self):
         self.buffer = ""
         self.closed = False
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index ca3858b184..0e52811948 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from six.moves import range
-
 from twisted.internet import defer, reactor
 from twisted.internet.defer import CancelledError
 
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 95301c013c..58ee918f65 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_make_deferred_yieldable(self):
-        # a function which retuns an incomplete deferred, but doesn't follow
+        # a function which returns an incomplete deferred, but doesn't follow
         # the synapse rules.
         def blocking_function():
             d = defer.Deferred()
@@ -183,7 +183,7 @@ class LoggingContextTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_make_deferred_yieldable_with_await(self):
-        # an async function which retuns an incomplete coroutine, but doesn't
+        # an async function which returns an incomplete coroutine, but doesn't
         # follow the synapse rules.
 
         async def blocking_function():
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e348694ad..5f46ed0cef 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
     def test_new_destination(self):
         """A happy-path case with a new destination and a successful operation"""
         store = self.hs.get_datastore()
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         # advance the clock a bit before making the request
         self.pump(1)
@@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
         with limiter:
             pass
 
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertIsNone(new_timings)
 
     def test_limiter(self):
         """General test case which walks through the process of a failing request"""
         store = self.hs.get_datastore()
 
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
         try:
@@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
         except AssertionError:
             pass
 
-        # wait for the update to land
-        self.pump()
-
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertEqual(new_timings["failure_ts"], failure_ts)
         self.assertEqual(new_timings["retry_last_ts"], failure_ts)
         self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
 
         # now if we try again we should get a failure
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        self.failureResultOf(d, NotRetryingDestination)
+        self.get_failure(
+            get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+        )
 
         #
         # advance the clock and try again
         #
 
         self.pump(MIN_RETRY_INTERVAL)
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
         try:
@@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         except AssertionError:
             pass
 
-        # wait for the update to land
-        self.pump()
-
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertEqual(new_timings["failure_ts"], failure_ts)
         self.assertEqual(new_timings["retry_last_ts"], retry_ts)
         self.assertGreaterEqual(
@@ -109,10 +91,8 @@ class RetryLimiterTestCase(HomeserverTestCase):
         #
         # one more go, with success
         #
-        self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
         with limiter:
@@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
         # wait for the update to land
         self.pump()
 
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertIsNone(new_timings)
diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py
index bd32e2cee7..d3dea3b52a 100644
--- a/tests/util/test_rwlock.py
+++ b/tests/util/test_rwlock.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
 
 from synapse.util.async_helpers import ReadWriteLock
 
@@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
             rwlock.read(key),  # 5
             rwlock.write(key),  # 6
         ]
+        ds = [defer.ensureDeferred(d) for d in ds]
 
         self._assert_called_before_not_after(ds, 2)
 
@@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
         with ds[6].result:
             pass
 
-        d = rwlock.write(key)
+        d = defer.ensureDeferred(rwlock.write(key))
         self.assertTrue(d.called)
         with d.result:
             pass
 
-        d = rwlock.read(key)
+        d = defer.ensureDeferred(rwlock.read(key))
         self.assertTrue(d.called)
         with d.result:
             pass
diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py
index 4f4da29a98..8491f7cc83 100644
--- a/tests/util/test_stringutils.py
+++ b/tests/util/test_stringutils.py
@@ -28,9 +28,6 @@ class StringUtilsTestCase(unittest.TestCase):
             "_--something==_",
             "...--==-18913",
             "8Dj2odd-e9asd.cd==_--ddas-secret-",
-            # We temporarily allow : characters: https://github.com/matrix-org/synapse/issues/6766
-            # To be removed in a future release
-            "SECRET:1234567890",
         ]
 
         bad = [
diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py
new file mode 100644
index 0000000000..5513724d87
--- /dev/null
+++ b/tests/util/test_threepids.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.threepids import canonicalise_email
+
+from tests.unittest import HomeserverTestCase
+
+
+class CanonicaliseEmailTests(HomeserverTestCase):
+    def test_no_at(self):
+        with self.assertRaises(ValueError):
+            canonicalise_email("address-without-at.bar")
+
+    def test_two_at(self):
+        with self.assertRaises(ValueError):
+            canonicalise_email("foo@foo@test.bar")
+
+    def test_bad_format(self):
+        with self.assertRaises(ValueError):
+            canonicalise_email("user@bad.example.net@good.example.com")
+
+    def test_valid_format(self):
+        self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
+
+    def test_domain_to_lower(self):
+        self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
+
+    def test_domain_with_umlaut(self):
+        self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
+
+    def test_address_casefold(self):
+        self.assertEqual(
+            canonicalise_email("Strauß@Example.com"), "strauss@example.com"
+        )
+
+    def test_address_trim(self):
+        self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
diff --git a/tests/utils.py b/tests/utils.py
index 59c020a051..4673872f88 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -21,9 +21,9 @@ import time
 import uuid
 import warnings
 from inspect import getcallargs
+from urllib import parse as urlparse
 
 from mock import Mock, patch
-from six.moves.urllib import parse as urlparse
 
 from twisted.internet import defer, reactor
 
@@ -154,6 +154,10 @@ def default_config(name, parse=False):
             "account": {"per_second": 10000, "burst_count": 10000},
             "failed_attempts": {"per_second": 10000, "burst_count": 10000},
         },
+        "rc_joins": {
+            "local": {"per_second": 10000, "burst_count": 10000},
+            "remote": {"per_second": 10000, "burst_count": 10000},
+        },
         "saml2_enabled": False,
         "public_baseurl": None,
         "default_identity_server": None,
@@ -168,6 +172,7 @@ def default_config(name, parse=False):
         # background, which upsets the test runner.
         "update_user_directory": False,
         "caches": {"global_factor": 1},
+        "listeners": [{"port": 0, "type": "http"}],
     }
 
     if parse:
@@ -467,7 +472,7 @@ class MockHttpResource(HttpServer):
             self.callbacks.append((method, path_pattern, callback))
 
 
-class MockKey(object):
+class MockKey:
     alg = "mock_alg"
     version = "mock_version"
     signature = b"\x9a\x87$"
@@ -486,7 +491,7 @@ class MockKey(object):
         return b"<fake_encoded_key>"
 
 
-class MockClock(object):
+class MockClock:
     now = 1000
 
     def __init__(self):
@@ -563,7 +568,7 @@ def _format_call(args, kwargs):
     )
 
 
-class DeferredMockCallable(object):
+class DeferredMockCallable:
     """A callable instance that stores a set of pending call expectations and
     return values for them. It allows a unit test to assert that the given set
     of function calls are eventually made, by awaiting on them to be called.
@@ -637,14 +642,8 @@ class DeferredMockCallable(object):
             )
 
 
-@defer.inlineCallbacks
-def create_room(hs, room_id, creator_id):
+async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room
-
-    Args:
-        hs
-        room_id (str)
-        creator_id (str)
     """
 
     persistence_store = hs.get_storage().persistence
@@ -652,7 +651,7 @@ def create_room(hs, room_id, creator_id):
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
-    yield store.store_room(
+    await store.store_room(
         room_id=room_id,
         room_creator_user_id=creator_id,
         is_public=False,
@@ -670,6 +669,6 @@ def create_room(hs, room_id, creator_id):
         },
     )
 
-    event, context = yield event_creation_handler.create_new_client_event(builder)
+    event, context = await event_creation_handler.create_new_client_event(builder)
 
-    yield persistence_store.persist_event(event, context)
+    await persistence_store.persist_event(event, context)