summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_admin.py9
-rw-r--r--tests/handlers/test_appservice.py33
-rw-r--r--tests/handlers/test_auth.py133
-rw-r--r--tests/handlers/test_cas.py60
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py20
-rw-r--r--tests/handlers/test_e2e_keys.py238
-rw-r--r--tests/handlers/test_e2e_room_keys.py347
-rw-r--r--tests/handlers/test_federation.py54
-rw-r--r--tests/handlers/test_message.py13
-rw-r--r--tests/handlers/test_oidc.py158
-rw-r--r--tests/handlers/test_password_providers.py11
-rw-r--r--tests/handlers/test_presence.py3
-rw-r--r--tests/handlers/test_profile.py125
-rw-r--r--tests/handlers/test_saml.py64
-rw-r--r--tests/handlers/test_typing.py12
-rw-r--r--tests/handlers/test_user_directory.py12
17 files changed, 649 insertions, 647 deletions
diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 5c2b4de1a6..a01fdd0839 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -44,8 +44,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.token2 = self.login("user2", "password")
 
     def test_single_public_joined_room(self):
-        """Test that we write *all* events for a public room
-        """
+        """Test that we write *all* events for a public room"""
         room_id = self.helper.create_room_as(
             self.user1, tok=self.token1, is_public=True
         )
@@ -116,8 +115,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
 
     def test_single_left_room(self):
-        """Tests that we don't see events in the room after we leave.
-        """
+        """Tests that we don't see events in the room after we leave."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
         self.helper.join(room_id, self.user2, tok=self.token2)
@@ -190,8 +188,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
 
     def test_invite(self):
-        """Tests that pending invites get handled correctly.
-        """
+        """Tests that pending invites get handled correctly."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
         self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 53763cd0f9..d5d3fdd99a 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -35,8 +35,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_scheduler = Mock()
         hs = Mock()
         hs.get_datastore.return_value = self.mock_store
-        self.mock_store.get_received_ts.return_value = defer.succeed(0)
-        self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+        self.mock_store.get_received_ts.return_value = make_awaitable(0)
+        self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
         hs.get_application_service_api.return_value = self.mock_as_api
         hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
@@ -50,16 +50,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested=False),
         ]
 
-        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_as_api.query_user.return_value = make_awaitable(True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = defer.succeed([])
+        self.mock_store.get_user_by_id.return_value = make_awaitable([])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            defer.succeed((0, [event])),
-            defer.succeed((0, [])),
+            make_awaitable((0, [event])),
+            make_awaitable((0, [])),
         ]
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
 
@@ -72,13 +72,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = defer.succeed(None)
+        self.mock_store.get_user_by_id.return_value = make_awaitable(None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_as_api.query_user.return_value = make_awaitable(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            defer.succeed((0, [event])),
-            defer.succeed((0, [])),
+            make_awaitable((0, [event])),
+            make_awaitable((0, [])),
         ]
 
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -90,13 +90,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
+        self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_as_api.query_user.return_value = make_awaitable(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            defer.succeed((0, [event])),
-            defer.succeed((0, [])),
+            make_awaitable((0, [event])),
+            make_awaitable((0, [])),
         ]
 
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -106,7 +106,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             "query_user called when it shouldn't have been.",
         )
 
-    @defer.inlineCallbacks
     def test_query_room_alias_exists(self):
         room_alias_str = "#foo:bar"
         room_alias = Mock()
@@ -127,8 +126,8 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             Mock(room_id=room_id, servers=servers)
         )
 
-        result = yield defer.ensureDeferred(
-            self.handler.query_room_alias_exists(room_alias)
+        result = self.successResultOf(
+            defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
         )
 
         self.mock_as_api.query_alias.assert_called_once_with(
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index e24ce81284..0e42013bb9 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -16,28 +16,21 @@ from mock import Mock
 
 import pymacaroons
 
-from twisted.internet import defer
-
-import synapse
-import synapse.api.errors
-from synapse.api.errors import ResourceLimitError
+from synapse.api.errors import AuthError, ResourceLimitError
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
-class AuthTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.auth_handler = self.hs.get_auth_handler()
-        self.macaroon_generator = self.hs.get_macaroon_generator()
+class AuthTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.auth_handler = hs.get_auth_handler()
+        self.macaroon_generator = hs.get_macaroon_generator()
 
         # MAU tests
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking = hs.get_auth()._auth_blocking
         self.auth_blocking._max_mau_value = 50
 
         self.small_number_of_users = 1
@@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
             self.fail("some_user was not in %s" % macaroon.inspect())
 
     def test_macaroon_caveats(self):
-        self.hs.get_clock().now = 5000
-
         token = self.macaroon_generator.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
@@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_nonce)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_gives_user_id(self):
-        self.hs.get_clock().now = 1000
-
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
         self.assertEqual("a_user", user_id)
 
         # when we advance the clock, the token should be rejected
-        self.hs.get_clock().now = 6000
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
-            )
+        self.reactor.advance(6)
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            AuthError,
+        )
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 macaroon.serialize()
             )
@@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase):
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    macaroon.serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                macaroon.serialize()
+            ),
+            AuthError,
+        )
 
-    @defer.inlineCallbacks
     def test_mau_limits_disabled(self):
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
         )
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
         )
 
-    @defer.inlineCallbacks
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
 
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
+        )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
 
-    @defer.inlineCallbacks
     def test_mau_limits_parity(self):
+        # Ensure we're not at the unix epoch.
+        self.reactor.advance(1)
         self.auth_blocking._limit_usage_by_mau = True
 
-        # If not in monthly active cohort
+        # Set the server to be at the edge of too many users.
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
+        # If not in monthly active cohort
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
+
         # If in monthly active cohort
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.clock.time_msec())
         )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
         )
-        self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
-        )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
         )
 
-    @defer.inlineCallbacks
     def test_mau_limits_not_exceeded(self):
         self.auth_blocking._limit_usage_by_mau = True
 
@@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         # Ensure does not raise exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
@@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index c37bb6440e..6f992291b8 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -16,7 +16,7 @@ from mock import Mock
 from synapse.handlers.cas_handler import CasResponse
 
 from tests.test_utils import simple_async_mock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
             "server_url": SERVER_URL,
             "service_url": BASE_URL,
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        cas_config.update(config.get("cas_config", {}))
         config["cas_config"] = cas_config
 
         return config
@@ -62,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -85,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -94,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -112,10 +116,54 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+        )
+
+    @override_config(
+        {
+            "cas_config": {
+                "required_attributes": {"userGroup": "staff", "department": None}
+            }
+        }
+    )
+    def test_required_attributes(self):
+        """The required attributes must be met from the CAS response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have any department.
+        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        cas_response = CasResponse(
+            "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 5dfeccfeb6..821629bc38 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -260,7 +260,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         # Create a new login for the user and dehydrated the device
         device_id, access_token = self.get_success(
             self.registration.register_device(
-                user_id=user_id, device_id=None, initial_display_name="new device",
+                user_id=user_id,
+                device_id=None,
+                initial_display_name="new device",
             )
         )
 
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index a39f898608..863d8737b2 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -131,7 +131,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         """A user can create an alias for a room they're in."""
         self.get_success(
             self.handler.create_association(
-                create_requester(self.test_user), self.room_alias, self.room_id,
+                create_requester(self.test_user),
+                self.room_alias,
+                self.room_id,
             )
         )
 
@@ -143,7 +145,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
 
         self.get_failure(
             self.handler.create_association(
-                create_requester(self.test_user), self.room_alias, other_room_id,
+                create_requester(self.test_user),
+                self.room_alias,
+                other_room_id,
             ),
             synapse.api.errors.SynapseError,
         )
@@ -156,7 +160,9 @@ class TestCreateAlias(unittest.HomeserverTestCase):
 
         self.get_success(
             self.handler.create_association(
-                create_requester(self.admin_user), self.room_alias, other_room_id,
+                create_requester(self.admin_user),
+                self.room_alias,
+                other_room_id,
             )
         )
 
@@ -275,8 +281,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
 
 
 class CanonicalAliasTestCase(unittest.HomeserverTestCase):
-    """Test modifications of the canonical alias when delete aliases.
-    """
+    """Test modifications of the canonical alias when delete aliases."""
 
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -317,7 +322,10 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
     def _set_canonical_alias(self, content):
         """Configure the canonical alias state on the room."""
         self.helper.send_state(
-            self.room_id, "m.room.canonical_alias", content, tok=self.admin_user_tok,
+            self.room_id,
+            "m.room.canonical_alias",
+            content,
+            tok=self.admin_user_tok,
         )
 
     def _get_canonical_alias(self):
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 924f29f051..5e86c5e56b 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -18,42 +18,26 @@ import mock
 
 from signedjson import key as key, sign as sign
 
-from twisted.internet import defer
-
-import synapse.handlers.e2e_keys
-import synapse.storage
-from synapse.api import errors
 from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.errors import Codes, SynapseError
 
-from tests import unittest, utils
+from tests import unittest
 
 
-class E2eKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eKeysHandler
-        self.store = None  # type: synapse.storage.Storage
+class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(federation_client=mock.Mock())
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, federation_client=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_query_local_devices_no_devices(self):
-        """If the user has no devices, we expect an empty list.
-        """
+        """If the user has no devices, we expect an empty list."""
         local_user = "@boris:" + self.hs.hostname
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
-        )
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    @defer.inlineCallbacks
     def test_reupload_one_time_keys(self):
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
@@ -64,7 +48,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
@@ -73,14 +57,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
-    @defer.inlineCallbacks
     def test_change_one_time_keys(self):
         """attempts to change one-time-keys should be rejected"""
 
@@ -92,75 +75,66 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
-                )
-            )
-            self.fail("No error when changing string key")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
-                )
-            )
-            self.fail("No error when replacing dict key with string")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {"one_time_keys": {"alg1:k1": {"key": "key"}}},
-                )
-            )
-            self.fail("No error when replacing string key with dict")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {
-                        "one_time_keys": {
-                            "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
-                        }
-                    },
-                )
-            )
-            self.fail("No error when replacing dict key")
-        except errors.SynapseError:
-            pass
+        # Error when changing string key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key with strin
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing string key with dict
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {
+                    "one_time_keys": {
+                        "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+                    }
+                },
+            ),
+            SynapseError,
+        )
 
-    @defer.inlineCallbacks
     def test_claim_one_time_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
 
-        res2 = yield defer.ensureDeferred(
+        res2 = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -173,7 +147,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_fallback_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
@@ -181,12 +154,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, [])
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
@@ -195,14 +168,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         # we should now have an unused alg1 key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, ["alg1"])
 
         # claiming an OTK when no OTKs are available should return the fallback
         # key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -213,13 +186,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         # we shouldn't have any unused fallback keys again
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, [])
 
         # claiming an OTK again should return the same fallback key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -231,22 +204,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": otk}
             )
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
         )
         self.assertEqual(
-            res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+            res,
+            {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -256,7 +230,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
-    @defer.inlineCallbacks
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
@@ -270,9 +243,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         keys2 = {
             "master_key": {
@@ -284,16 +255,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys2)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
 
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
-    @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
@@ -326,9 +294,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
         )
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         # upload two device keys, which will be signed later by the self-signing key
         device_key_1 = {
@@ -358,12 +324,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
         }
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, "abc", {"device_keys": device_key_1}
             )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, "def", {"device_keys": device_key_2}
             )
@@ -372,7 +338,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # sign the first device key and upload it
         del device_key_1["signatures"]
         sign.sign_json(device_key_1, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1}}
             )
@@ -383,7 +349,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # signature for it
         del device_key_2["signatures"]
         sign.sign_json(device_key_2, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
             )
@@ -391,7 +357,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
@@ -399,7 +365,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
-    @defer.inlineCallbacks
     def test_self_signing_key_doesnt_show_up_as_device(self):
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
@@ -413,29 +378,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
-
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.hs.get_device_handler().check_device_registered(
-                    user_id=local_user,
-                    device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
-                    initial_device_display_name="new display name",
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
-        self.assertEqual(res, 400)
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
+        e = self.get_failure(
+            self.hs.get_device_handler().check_device_registered(
+                user_id=local_user,
+                device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+                initial_device_display_name="new display name",
+            ),
+            SynapseError,
         )
+        res = e.value.code
+        self.assertEqual(res, 400)
+
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    @defer.inlineCallbacks
     def test_upload_signatures(self):
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
@@ -458,7 +416,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
         )
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"device_keys": device_key}
             )
@@ -501,7 +459,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "user_signing_key": usersigning_key,
             "self_signing_key": selfsigning_key,
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
         )
 
@@ -515,14 +473,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "usage": ["master"],
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(
                 other_user, {"master_key": other_master_key}
             )
         )
 
         # test various signature failures (see below)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 {
@@ -602,20 +560,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         user_failures = ret["failures"][local_user]
+        self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
         self.assertEqual(
-            user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+            user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
         )
-        self.assertEqual(
-            user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
-        )
-        self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+        self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
 
         other_user_failures = ret["failures"][other_user]
+        self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
         self.assertEqual(
-            other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
-        )
-        self.assertEqual(
-            other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+            other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
         )
 
         # test successful signatures
@@ -623,7 +577,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 {
@@ -636,7 +590,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(ret["failures"], {})
 
         # fetch the signed keys/devices and make sure that the signatures are there
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.query_devices(
                 {"device_keys": {local_user: [], other_user: []}}, 0, local_user
             )
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 45f201a399..d7498aa51a 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -19,14 +19,9 @@ import copy
 
 import mock
 
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
 
-import synapse.api.errors
-import synapse.handlers.e2e_room_keys
-import synapse.storage
-from synapse.api import errors
-
-from tests import unittest, utils
+from tests import unittest
 
 # sample room_key data for use in the tests
 room_keys = {
@@ -45,51 +40,38 @@ room_keys = {
 }
 
 
-class E2eRoomKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(replication_layer=mock.Mock())
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, replication_layer=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
-        self.local_user = "@boris:" + self.hs.hostname
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_room_keys_handler()
+        self.local_user = "@boris:" + hs.hostname
 
-    @defer.inlineCallbacks
     def test_get_missing_current_version_info(self):
         """Check that we get a 404 if we ask for info about the current version
         if there is no version.
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_version_info(self):
         """Check that we get a 404 if we ask for info about a specific version
         if it doesn't exist.
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "bogus_version"),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_create_version(self):
-        """Check that we can create and then retrieve versions.
-        """
-        res = yield defer.ensureDeferred(
+        """Check that we can create and then retrieve versions."""
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -101,7 +83,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         version_etag = res["etag"]
         self.assertIsInstance(version_etag, str)
         del res["etag"]
@@ -116,9 +98,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as a specific version
-        res = yield defer.ensureDeferred(
-            self.handler.get_version_info(self.local_user, "1")
-        )
+        res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
         self.assertEqual(res["etag"], version_etag)
         del res["etag"]
         self.assertDictEqual(
@@ -132,7 +112,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # upload a new one...
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -144,7 +124,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "2")
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -156,11 +136,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_version(self):
-        """Check that we can update versions.
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we can update versions."""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -171,7 +149,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.update_version(
                 self.local_user,
                 version,
@@ -185,7 +163,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {})
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -197,32 +175,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_missing_version(self):
-        """Check that we get a 404 on updating nonexistent versions
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    "1",
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "1",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on updating nonexistent versions"""
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                "1",
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "1",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_update_omitted_version(self):
-        """Check that the update succeeds if the version is missing from the body
-        """
-        version = yield defer.ensureDeferred(
+        """Check that the update succeeds if the version is missing from the body"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -233,7 +205,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.update_version(
                 self.local_user,
                 version,
@@ -245,7 +217,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]  # etag is opaque, so don't test its contents
         self.assertDictEqual(
             res,
@@ -257,11 +229,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_bad_version(self):
-        """Check that we get a 400 if the version in the body doesn't match
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we get a 400 if the version in the body doesn't match"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -272,52 +242,38 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    version,
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "incorrect",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "incorrect",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 400)
 
-    @defer.inlineCallbacks
     def test_delete_missing_version(self):
-        """Check that we get a 404 on deleting nonexistent versions
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.delete_version(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on deleting nonexistent versions"""
+        e = self.get_failure(
+            self.handler.delete_version(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_delete_missing_current_version(self):
-        """Check that we get a 404 on deleting nonexistent current version
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on deleting nonexistent current version"""
+        e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_delete_version(self):
-        """Check that we can create and then delete versions.
-        """
-        res = yield defer.ensureDeferred(
+        """Check that we can create and then delete versions."""
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -329,36 +285,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
 
         # check we can delete it
-        yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
+        self.get_success(self.handler.delete_version(self.local_user, "1"))
 
         # check that it's gone
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_backup(self):
-        """Check that we get a 404 on querying missing backup
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_room_keys(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on querying missing backup"""
+        e = self.get_failure(
+            self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_room_keys(self):
-        """Check we get an empty response from an empty backup
-        """
-        version = yield defer.ensureDeferred(
+        """Check we get an empty response from an empty backup"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -369,33 +315,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, {"rooms": {}})
 
     # TODO: test the locking semantics when uploading room_keys,
     # although this is probably best done in sytest
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_no_versions(self):
-        """Check that we get a 404 on uploading keys when no versions are defined
-        """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        """Check that we get a 404 on uploading keys when no versions are defined"""
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_bogus_version(self):
         """Check that we get a 404 on uploading keys when an nonexistent version
         is specified
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -406,22 +345,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(
-                    self.local_user, "bogus_version", room_keys
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_wrong_version(self):
-        """Check that we get a 403 on uploading keys for an old version
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we get a 403 on uploading keys for an old version"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -432,7 +365,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -443,20 +376,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "2")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "1", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 403)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_insert(self):
-        """Check that we can insert and retrieve keys for a session
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we can insert and retrieve keys for a session"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -467,17 +395,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given room
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
             )
@@ -485,18 +411,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given session_id
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
         )
         self.assertDictEqual(res, room_keys)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_merge(self):
         """Check that we can upload a new room_key for an existing session and
         have it correctly merged"""
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -507,12 +432,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
 
         # get the etag to compare to future versions
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         backup_etag = res["etag"]
         self.assertEqual(res["count"], 1)
 
@@ -522,37 +447,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # test that increasing the message_index doesn't replace the existing session
         new_room_key["first_message_index"] = 2
         new_room_key["session_data"] = "new"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
         )
 
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # test that marking the session as verified however /does/ replace it
         new_room_key["is_verified"] = True
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should NOT be equal now, since the key changed
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertNotEqual(res["etag"], backup_etag)
         backup_etag = res["etag"]
 
@@ -560,28 +481,24 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # with a lower forwarding count
         new_room_key["forwarded_count"] = 2
         new_room_key["session_data"] = "other"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # TODO: check edge cases as well as the common variations here
 
-    @defer.inlineCallbacks
     def test_delete_room_keys(self):
-        """Check that we can insert and delete keys for a session
-        """
-        version = yield defer.ensureDeferred(
+        """Check that we can insert and delete keys for a session"""
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -593,13 +510,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(version, "1")
 
         # check for bulk-delete
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
-            self.handler.delete_room_keys(self.local_user, version)
-        )
-        res = yield defer.ensureDeferred(
+        self.get_success(self.handler.delete_room_keys(self.local_user, version))
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
@@ -607,15 +522,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per room
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
             )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
@@ -623,15 +538,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per session
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 0b24b89a2e..3af361195b 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
 from unittest import TestCase
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
 from synapse.federation.federation_base import event_from_pdu_json
@@ -191,6 +191,58 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invite_by_user_ratelimit(self):
+        """Tests that invites from federation to a particular user are
+        actually rate-limited.
+        """
+        other_server = "otherserver"
+        other_user = "@otheruser:" + other_server
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        def create_invite():
+            room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+            room_version = self.get_success(self.store.get_room_version(room_id))
+            return event_from_pdu_json(
+                {
+                    "type": EventTypes.Member,
+                    "content": {"membership": "invite"},
+                    "room_id": room_id,
+                    "sender": other_user,
+                    "state_key": "@user:test",
+                    "depth": 32,
+                    "prev_events": [],
+                    "auth_events": [],
+                    "origin_server_ts": self.clock.time_msec(),
+                },
+                room_version,
+            )
+
+        for i in range(3):
+            event = create_invite()
+            self.get_success(
+                self.handler.on_invite_request(
+                    other_server,
+                    event,
+                    event.room_version,
+                )
+            )
+
+        event = create_invite()
+        self.get_failure(
+            self.handler.on_invite_request(
+                other_server,
+                event,
+                event.room_version,
+            ),
+            exc=LimitExceededError,
+        )
+
     def _build_and_send_join_event(self, other_server, other_user, room_id):
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f955dfa490..a0d1ebdbe3 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -44,7 +44,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
         self.info = self.get_success(
-            self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+            self.hs.get_datastore().get_user_by_access_token(
+                self.access_token,
+            )
         )
         self.token_id = self.info.token_id
 
@@ -169,8 +171,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
     def test_allow_server_acl(self):
-        """Test that sending an ACL that blocks everyone but ourselves works.
-        """
+        """Test that sending an ACL that blocks everyone but ourselves works."""
 
         self.helper.send_state(
             self.room_id,
@@ -181,8 +182,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
         )
 
     def test_deny_server_acl_block_outselves(self):
-        """Test that sending an ACL that blocks ourselves does not work.
-        """
+        """Test that sending an ACL that blocks ourselves does not work."""
         self.helper.send_state(
             self.room_id,
             EventTypes.ServerACL,
@@ -192,8 +192,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
         )
 
     def test_deny_redact_server_acl(self):
-        """Test that attempting to redact an ACL is blocked.
-        """
+        """Test that attempting to redact an ACL is blocked."""
 
         body = self.helper.send_state(
             self.room_id,
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b3dfa40d25..cf1de28fa9 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -24,7 +24,7 @@ from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
 
-from tests.test_utils import FakeResponse, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 try:
@@ -40,7 +40,7 @@ ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
 CLIENT_SECRET = "test-client-secret"
 BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
 AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -58,12 +58,6 @@ COMMON_CONFIG = {
 }
 
 
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
 class TestMappingProvider:
     @staticmethod
     def parse_config(config):
@@ -137,7 +131,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return config
 
     def make_homeserver(self, reactor, clock):
-
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
         self.http_client.user_agent = "Synapse Test"
@@ -157,7 +150,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.provider._provider_metadata, values)
+        """Modify the result that will be returned by the well-known query"""
+
+        async def patched_get_json(uri):
+            res = await get_json(uri)
+            if uri == WELL_KNOWN:
+                res.update(values)
+            return res
+
+        return patch.object(self.http_client, "get_json", patched_get_json)
 
     def assertRenderedError(self, error, error_description=None):
         self.render_error.assert_called_once()
@@ -218,7 +219,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
-        with self.metadata_edit({"jwks_uri": None}):
+        original = self.provider.load_metadata
+
+        async def patched_load_metadata():
+            m = (await original()).copy()
+            m.update({"jwks_uri": None})
+            return m
+
+        with patch.object(self.provider, "load_metadata", patched_load_metadata):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
@@ -228,55 +236,60 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
-    @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
         h = self.provider
 
+        def force_load_metadata():
+            async def force_load():
+                return await h.load_metadata(force=True)
+
+            return get_awaitable_result(force_load())
+
         # Default test config does not throw
-        h._validate_metadata()
+        force_load_metadata()
 
         with self.metadata_edit({"issuer": None}):
-            self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
 
         with self.metadata_edit({"issuer": "http://insecure/"}):
-            self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
 
         with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
-            self.assertRaisesRegex(ValueError, "issuer", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)
 
         with self.metadata_edit({"authorization_endpoint": None}):
             self.assertRaisesRegex(
-                ValueError, "authorization_endpoint", h._validate_metadata
+                ValueError, "authorization_endpoint", force_load_metadata
             )
 
         with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
             self.assertRaisesRegex(
-                ValueError, "authorization_endpoint", h._validate_metadata
+                ValueError, "authorization_endpoint", force_load_metadata
             )
 
         with self.metadata_edit({"token_endpoint": None}):
-            self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
 
         with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
-            self.assertRaisesRegex(ValueError, "token_endpoint", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)
 
         with self.metadata_edit({"jwks_uri": None}):
-            self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
 
         with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
-            self.assertRaisesRegex(ValueError, "jwks_uri", h._validate_metadata)
+            self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)
 
         with self.metadata_edit({"response_types_supported": ["id_token"]}):
             self.assertRaisesRegex(
-                ValueError, "response_types_supported", h._validate_metadata
+                ValueError, "response_types_supported", force_load_metadata
             )
 
         with self.metadata_edit(
             {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
         ):
             # should not throw, as client_secret_basic is the default auth method
-            h._validate_metadata()
+            force_load_metadata()
 
         with self.metadata_edit(
             {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
@@ -284,7 +297,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRaisesRegex(
                 ValueError,
                 "token_endpoint_auth_methods_supported",
-                h._validate_metadata,
+                force_load_metadata,
             )
 
         # Tests for configs that require the userinfo endpoint
@@ -293,28 +306,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         h._user_profile_method = "userinfo_endpoint"
         self.assertTrue(h._uses_userinfo)
 
-        # Revert the profile method and do not request the "openid" scope.
+        # Revert the profile method and do not request the "openid" scope: this should
+        # mean that we check for a userinfo endpoint
         h._user_profile_method = "auto"
         h._scopes = []
         self.assertTrue(h._uses_userinfo)
-        self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
+        with self.metadata_edit({"userinfo_endpoint": None}):
+            self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)
 
-        with self.metadata_edit(
-            {"userinfo_endpoint": USERINFO_ENDPOINT, "jwks_uri": None}
-        ):
-            # Shouldn't raise with a valid userinfo, even without
-            h._validate_metadata()
+        with self.metadata_edit({"jwks_uri": None}):
+            # Shouldn't raise with a valid userinfo, even without jwks
+            force_load_metadata()
 
     @override_config({"oidc_config": {"skip_verification": True}})
     def test_skip_verification(self):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.provider._validate_metadata()
+            get_awaitable_result(self.provider.load_metadata())
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
-        req = Mock(spec=["addCookie"])
+        req = Mock(spec=["cookies"])
+        req.cookies = []
+
         url = self.get_success(
             self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
@@ -333,16 +348,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(len(params["state"]), 1)
         self.assertEqual(len(params["nonce"]), 1)
 
-        # Check what is in the cookie
-        # note: python3.5 mock does not have the .called_once() method
-        calls = req.addCookie.call_args_list
-        self.assertEqual(len(calls), 1)  # called once
-        # For some reason, call.args does not work with python3.5
-        args = calls[0][0]
-        kwargs = calls[0][1]
-        self.assertEqual(args[0], COOKIE_NAME)
-        self.assertEqual(kwargs["path"], COOKIE_PATH)
-        cookie = args[1]
+        # Check what is in the cookies
+        self.assertEqual(len(req.cookies), 2)  # two cookies
+        cookie_header = req.cookies[0]
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        parts = [p.strip() for p in cookie_header.split(b";")]
+        self.assertIn(b"Path=/_synapse/client/oidc", parts)
+        name, cookie = parts[0].split(b"=")
+        self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
         state = self.handler._token_generator._get_value_from_macaroon(
@@ -419,7 +434,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None,
+            expected_user_id, request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -450,7 +465,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None,
+            expected_user_id, request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -473,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
-        request = Mock(spec=["args", "getCookie", "addCookie"])
+        request = Mock(spec=["args", "getCookie", "cookies"])
 
         # Missing cookie
         request.args = {}
@@ -496,7 +511,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # Mismatching session
         session = self._generate_oidc_session_token(
-            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
+            state="state",
+            nonce="nonce",
+            client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -551,7 +568,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Internal server error with no JSON body
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(
-                code=500, phrase=b"Internal Server Error", body=b"Not JSON",
+                code=500,
+                phrase=b"Internal Server Error",
+                body=b"Not JSON",
             )
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
@@ -571,7 +590,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
+            return_value=FakeResponse(
+                code=400,
+                phrase=b"Bad request",
+                body=b"{}",
+            )
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
@@ -579,7 +602,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # 2xx error with "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(
-                code=200, phrase=b"OK", body=b'{"error": "some_error"}',
+                code=200,
+                phrase=b"OK",
+                body=b'{"error": "some_error"}',
             )
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
@@ -616,14 +641,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state = "state"
         client_redirect_url = "http://client/redirect"
         session = self._generate_oidc_session_token(
-            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
+            state=state,
+            nonce="nonce",
+            client_redirect_url=client_redirect_url,
         )
         request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@foo:test", request, client_redirect_url, {"phone": "1234567"},
+            "@foo:test",
+            request,
+            client_redirect_url,
+            {"phone": "1234567"},
+            new_user=True,
         )
 
     def test_map_userinfo_to_user(self):
@@ -637,7 +668,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None,
+            "@test_user:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -648,7 +679,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None,
+            "@test_user_2:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -685,14 +716,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -707,7 +738,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None,
+            user.to_string(), ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -743,7 +774,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None,
+            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -779,7 +810,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None,
+            "@test_user1:test", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -875,7 +906,9 @@ async def _make_callback_with_userinfo(
     session = handler._token_generator.generate_oidc_session_token(
         state=state,
         session_data=OidcSessionData(
-            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+            idp_id="oidc",
+            nonce="nonce",
+            client_redirect_url=client_redirect_url,
         ),
     )
     request = _build_callback_request("code", state, session)
@@ -909,13 +942,14 @@ def _build_callback_request(
         spec=[
             "args",
             "getCookie",
-            "addCookie",
+            "cookies",
             "requestHeaders",
             "getClientIP",
             "getHeader",
         ]
     )
 
+    request.cookies = []
     request.getCookie.return_value = session
     request.args = {}
     request.args[b"code"] = [code.encode("utf-8")]
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index f816594ee4..a98a65ae67 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -231,8 +231,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_no_local_user_fallback_login(self):
-        """localdb_enabled can block login with the local password
-        """
+        """localdb_enabled can block login with the local password"""
         self.register_user("localuser", "localpass")
 
         # check_password must return an awaitable
@@ -251,8 +250,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_no_local_user_fallback_ui_auth(self):
-        """localdb_enabled can block ui auth with the local password
-        """
+        """localdb_enabled can block ui auth with the local password"""
         self.register_user("localuser", "localpass")
 
         # allow login via the auth provider
@@ -594,7 +592,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         )
 
     def _delete_device(
-        self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+        self,
+        access_token: str,
+        device: str,
+        body: Union[JsonDict, bytes] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
         channel = self.make_request(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 0794b32c9c..be2ee26f07 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -589,8 +589,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
 
     def _add_new_user(self, room_id, user_id):
-        """Add new user to the room by creating an event and poking the federation API.
-        """
+        """Add new user to the room by creating an event and poking the federation API."""
 
         hostname = get_domain_from_id(user_id)
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 022943a10a..18ca8b84f5 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,25 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
-class ProfileTestCase(unittest.TestCase):
+class ProfileTestCase(unittest.HomeserverTestCase):
     """ Tests profile management. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
+    def make_homeserver(self, reactor, clock):
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
@@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase):
 
         self.mock_registry.register_query_handler = register_query_handler
 
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
+        hs = self.setup_test_homeserver(
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
         )
+        return hs
 
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.frank = UserID.from_string("@1234ABCD:test")
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+        self.get_success(self.store.create_profile(self.frank.localpart))
 
         self.handler = hs.get_profile_handler()
-        self.hs = hs
 
-    @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.frank)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.frank))
 
         self.assertEquals("Frank", displayname)
 
-    @defer.inlineCallbacks
     def test_set_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
             )
@@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set displayname again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank"
             )
@@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set displayname to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), ""
             )
         )
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            )
+            (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
         )
 
-    @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Setting displayname a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
-            )
+            ),
+            SynapseError,
         )
 
-        yield self.assertFailure(d, SynapseError)
-
-    @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
-            )
+            ),
+            AuthError,
         )
 
-        yield self.assertFailure(d, AuthError)
-
-    @defer.inlineCallbacks
     def test_get_other_name(self):
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.alice)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.alice))
 
         self.assertEquals(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
@@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase):
             ignore_backoff=True,
         )
 
-    @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield defer.ensureDeferred(self.store.create_profile("caroline"))
-        yield defer.ensureDeferred(
-            self.store.set_profile_displayname("caroline", "Caroline")
-        )
+        self.get_success(self.store.create_profile("caroline"))
+        self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
-        response = yield defer.ensureDeferred(
+        response = self.get_success(
             self.query_handlers["profile"](
                 {"user_id": "@caroline:test", "field": "displayname"}
             )
@@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals({"displayname": "Caroline"}, response)
 
-    @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
             )
         )
-        avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
+        avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
 
         self.assertEquals("http://my.server/me.png", avatar_url)
 
-    @defer.inlineCallbacks
     def test_set_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
@@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/pic.gif",
         )
 
         # Set avatar again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
@@ -230,56 +201,44 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
 
         # Set avatar to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
-                self.frank, synapse.types.create_requester(self.frank), "",
+                self.frank,
+                synapse.types.create_requester(self.frank),
+                "",
             )
         )
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
 
-    @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
             )
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
 
         # Set avatar a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
                 "http://my.server/pic.gif",
-            )
+            ),
+            SynapseError,
         )
-
-        yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 261c7083d1..029af2853e 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None
+            "@test_user:test", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None
+            "@test_user:test", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None
+            "@test_user1:test", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
 
+    @override_config(
+        {
+            "saml2_config": {
+                "attribute_requirements": [
+                    {"attribute": "userGroup", "value": "staff"},
+                    {"attribute": "department", "value": "sales"},
+                ],
+            },
+        }
+    )
+    def test_attribute_requirements(self):
+        """The required attributes must be met from the SAML response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have the proper department.
+        saml_response = FakeAuthnResponse(
+            {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
+        )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        saml_response = FakeAuthnResponse(
+            {
+                "uid": "test_user",
+                "username": "test_user",
+                "userGroup": ["staff", "admin"],
+                "department": ["sales"],
+            }
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 96e5bdac4a..24e7138196 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -143,14 +143,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
 
         self.datastore.get_to_device_stream_token = lambda: 0
-        self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
-            ([], 0)
+        self.datastore.get_new_device_msgs_for_remote = (
+            lambda *args, **kargs: make_awaitable(([], 0))
         )
-        self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
-            None
+        self.datastore.delete_device_msgs_for_remote = (
+            lambda *args, **kargs: make_awaitable(None)
         )
-        self.datastore.set_received_txn_response = lambda *args, **kwargs: make_awaitable(
-            None
+        self.datastore.set_received_txn_response = (
+            lambda *args, **kwargs: make_awaitable(None)
         )
 
     def test_started_typing_local(self):
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9c886d671a..3572e54c5d 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -200,7 +200,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Check that the room has an encryption state event
         event_content = self.helper.get_state(
-            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
         )
         self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
 
@@ -209,7 +211,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Check that the room has an encryption state event
         event_content = self.helper.get_state(
-            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
         )
         self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})
 
@@ -227,7 +231,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Check that the room has an encryption state event
         event_content = self.helper.get_state(
-            room_id=room_id, event_type=EventTypes.RoomEncryption, tok=user_token,
+            room_id=room_id,
+            event_type=EventTypes.RoomEncryption,
+            tok=user_token,
         )
         self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT})