summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/config/test_server.py61
-rw-r--r--tests/handlers/test_auth.py133
-rw-r--r--tests/handlers/test_cas.py52
-rw-r--r--tests/handlers/test_e2e_keys.py230
-rw-r--r--tests/handlers/test_e2e_room_keys.py305
-rw-r--r--tests/handlers/test_profile.py121
-rw-r--r--tests/handlers/test_saml.py56
-rw-r--r--tests/push/test_email.py51
-rw-r--r--tests/rest/admin/test_room.py84
-rw-r--r--tests/rest/client/v1/test_login.py16
-rw-r--r--tests/rest/client/v1/test_typing.py28
-rw-r--r--tests/rest/media/v1/test_media_storage.py94
-rw-r--r--tests/test_preview.py103
13 files changed, 797 insertions, 537 deletions
diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index a10d017120..98af7aa675 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -15,7 +15,8 @@
 
 import yaml
 
-from synapse.config.server import ServerConfig, is_threepid_reserved
+from synapse.config._base import ConfigError
+from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
 
 from tests import unittest
 
@@ -128,3 +129,61 @@ class ServerConfigTestCase(unittest.TestCase):
         )
 
         self.assertEqual(conf["listeners"], expected_listeners)
+
+
+class GenerateIpSetTestCase(unittest.TestCase):
+    def test_empty(self):
+        ip_set = generate_ip_set(())
+        self.assertFalse(ip_set)
+
+        ip_set = generate_ip_set((), ())
+        self.assertFalse(ip_set)
+
+    def test_generate(self):
+        """Check adding IPv4 and IPv6 addresses."""
+        # IPv4 address
+        ip_set = generate_ip_set(("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        # IPv4 CIDR
+        ip_set = generate_ip_set(("1.2.3.4/24",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        # IPv6 address
+        ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+        # IPv6 CIDR
+        ip_set = generate_ip_set(("2001:db8::/104",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+        # The addresses can overlap OK.
+        ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+    def test_extra(self):
+        """Extra IP addresses are treated the same."""
+        ip_set = generate_ip_set((), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+        ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 8)
+
+        # They can duplicate without error.
+        ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
+        self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+    def test_bad_value(self):
+        """An error should be raised if a bad value is passed in."""
+        with self.assertRaises(ConfigError):
+            generate_ip_set(("not-an-ip",))
+
+        with self.assertRaises(ConfigError):
+            generate_ip_set(("1.2.3.4/128",))
+
+        with self.assertRaises(ConfigError):
+            generate_ip_set((":::",))
+
+        # The following get treated as empty data.
+        self.assertFalse(generate_ip_set(None))
+        self.assertFalse(generate_ip_set({}))
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index e24ce81284..0e42013bb9 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -16,28 +16,21 @@ from mock import Mock
 
 import pymacaroons
 
-from twisted.internet import defer
-
-import synapse
-import synapse.api.errors
-from synapse.api.errors import ResourceLimitError
+from synapse.api.errors import AuthError, ResourceLimitError
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
-class AuthTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.auth_handler = self.hs.get_auth_handler()
-        self.macaroon_generator = self.hs.get_macaroon_generator()
+class AuthTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.auth_handler = hs.get_auth_handler()
+        self.macaroon_generator = hs.get_macaroon_generator()
 
         # MAU tests
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking = hs.get_auth()._auth_blocking
         self.auth_blocking._max_mau_value = 50
 
         self.small_number_of_users = 1
@@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
             self.fail("some_user was not in %s" % macaroon.inspect())
 
     def test_macaroon_caveats(self):
-        self.hs.get_clock().now = 5000
-
         token = self.macaroon_generator.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
@@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_nonce)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_gives_user_id(self):
-        self.hs.get_clock().now = 1000
-
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
         self.assertEqual("a_user", user_id)
 
         # when we advance the clock, the token should be rejected
-        self.hs.get_clock().now = 6000
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
-            )
+        self.reactor.advance(6)
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            AuthError,
+        )
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 macaroon.serialize()
             )
@@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase):
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    macaroon.serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                macaroon.serialize()
+            ),
+            AuthError,
+        )
 
-    @defer.inlineCallbacks
     def test_mau_limits_disabled(self):
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
         )
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
         )
 
-    @defer.inlineCallbacks
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
 
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
+        )
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
 
-    @defer.inlineCallbacks
     def test_mau_limits_parity(self):
+        # Ensure we're not at the unix epoch.
+        self.reactor.advance(1)
         self.auth_blocking._limit_usage_by_mau = True
 
-        # If not in monthly active cohort
+        # Set the server to be at the edge of too many users.
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
+        # If not in monthly active cohort
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
+
         # If in monthly active cohort
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.clock.time_msec())
         )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
         )
-        self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
-        )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
         )
 
-    @defer.inlineCallbacks
     def test_mau_limits_not_exceeded(self):
         self.auth_blocking._limit_usage_by_mau = True
 
@@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         # Ensure does not raise exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
             )
@@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
             )
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7baf224f7e..6f992291b8 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -16,7 +16,7 @@ from mock import Mock
 from synapse.handlers.cas_handler import CasResponse
 
 from tests.test_utils import simple_async_mock
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
@@ -32,6 +32,10 @@ class CasHandlerTestCase(HomeserverTestCase):
             "server_url": SERVER_URL,
             "service_url": BASE_URL,
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        cas_config.update(config.get("cas_config", {}))
         config["cas_config"] = cas_config
 
         return config
@@ -115,7 +119,51 @@ class CasHandlerTestCase(HomeserverTestCase):
             "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
         )
 
+    @override_config(
+        {
+            "cas_config": {
+                "required_attributes": {"userGroup": "staff", "department": None}
+            }
+        }
+    )
+    def test_required_attributes(self):
+        """The required attributes must be met from the CAS response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have any department.
+        cas_response = CasResponse("test_user", {"userGroup": "staff"})
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        cas_response = CasResponse(
+            "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 924f29f051..c1a13aeb71 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -18,42 +18,27 @@ import mock
 
 from signedjson import key as key, sign as sign
 
-from twisted.internet import defer
-
-import synapse.handlers.e2e_keys
-import synapse.storage
-from synapse.api import errors
 from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.errors import Codes, SynapseError
 
-from tests import unittest, utils
+from tests import unittest
 
 
-class E2eKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eKeysHandler
-        self.store = None  # type: synapse.storage.Storage
+class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(federation_client=mock.Mock())
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, federation_client=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_query_local_devices_no_devices(self):
         """If the user has no devices, we expect an empty list.
         """
         local_user = "@boris:" + self.hs.hostname
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
-        )
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    @defer.inlineCallbacks
     def test_reupload_one_time_keys(self):
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
@@ -64,7 +49,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
@@ -73,14 +58,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
-    @defer.inlineCallbacks
     def test_change_one_time_keys(self):
         """attempts to change one-time-keys should be rejected"""
 
@@ -92,75 +76,64 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
         }
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
-                )
-            )
-            self.fail("No error when changing string key")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
-                )
-            )
-            self.fail("No error when replacing dict key with string")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {"one_time_keys": {"alg1:k1": {"key": "key"}}},
-                )
-            )
-            self.fail("No error when replacing string key with dict")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {
-                        "one_time_keys": {
-                            "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
-                        }
-                    },
-                )
-            )
-            self.fail("No error when replacing dict key")
-        except errors.SynapseError:
-            pass
+        # Error when changing string key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key with strin
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing string key with dict
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {
+                    "one_time_keys": {
+                        "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+                    }
+                },
+            ),
+            SynapseError,
+        )
 
-    @defer.inlineCallbacks
     def test_claim_one_time_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
             )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
 
-        res2 = yield defer.ensureDeferred(
+        res2 = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -173,7 +146,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_fallback_key(self):
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
@@ -181,12 +153,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         otk = {"alg1:k2": "key2"}
 
         # we shouldn't have any unused fallback keys yet
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, [])
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user,
                 device_id,
@@ -195,14 +167,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         # we should now have an unused alg1 key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, ["alg1"])
 
         # claiming an OTK when no OTKs are available should return the fallback
         # key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -213,13 +185,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         # we shouldn't have any unused fallback keys again
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         self.assertEqual(res, [])
 
         # claiming an OTK again should return the same fallback key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -231,13 +203,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": otk}
             )
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -246,7 +218,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
         )
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
@@ -256,7 +228,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
 
-    @defer.inlineCallbacks
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
@@ -270,9 +241,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         keys2 = {
             "master_key": {
@@ -284,16 +253,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys2)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
 
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
-    @defer.inlineCallbacks
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
@@ -326,9 +292,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
         )
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
         # upload two device keys, which will be signed later by the self-signing key
         device_key_1 = {
@@ -358,12 +322,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
         }
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, "abc", {"device_keys": device_key_1}
             )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, "def", {"device_keys": device_key_2}
             )
@@ -372,7 +336,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # sign the first device key and upload it
         del device_key_1["signatures"]
         sign.sign_json(device_key_1, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1}}
             )
@@ -383,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # signature for it
         del device_key_2["signatures"]
         sign.sign_json(device_key_2, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
             )
@@ -391,7 +355,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
@@ -399,7 +363,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
-    @defer.inlineCallbacks
     def test_self_signing_key_doesnt_show_up_as_device(self):
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
@@ -413,29 +376,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
             }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
-
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.hs.get_device_handler().check_device_registered(
-                    user_id=local_user,
-                    device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
-                    initial_device_display_name="new display name",
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
-        self.assertEqual(res, 400)
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
+        e = self.get_failure(
+            self.hs.get_device_handler().check_device_registered(
+                user_id=local_user,
+                device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+                initial_device_display_name="new display name",
+            ),
+            SynapseError,
         )
+        res = e.value.code
+        self.assertEqual(res, 400)
+
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
 
-    @defer.inlineCallbacks
     def test_upload_signatures(self):
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
@@ -458,7 +414,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
         )
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"device_keys": device_key}
             )
@@ -501,7 +457,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "user_signing_key": usersigning_key,
             "self_signing_key": selfsigning_key,
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
         )
 
@@ -515,14 +471,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "usage": ["master"],
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(
                 other_user, {"master_key": other_master_key}
             )
         )
 
         # test various signature failures (see below)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 {
@@ -602,20 +558,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
 
         user_failures = ret["failures"][local_user]
+        self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
         self.assertEqual(
-            user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+            user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
         )
-        self.assertEqual(
-            user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
-        )
-        self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+        self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
 
         other_user_failures = ret["failures"][other_user]
+        self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
         self.assertEqual(
-            other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
-        )
-        self.assertEqual(
-            other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+            other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
         )
 
         # test successful signatures
@@ -623,7 +575,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 {
@@ -636,7 +588,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(ret["failures"], {})
 
         # fetch the signed keys/devices and make sure that the signatures are there
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.query_devices(
                 {"device_keys": {local_user: [], other_user: []}}, 0, local_user
             )
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 45f201a399..58773a0c38 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -19,14 +19,9 @@ import copy
 
 import mock
 
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
 
-import synapse.api.errors
-import synapse.handlers.e2e_room_keys
-import synapse.storage
-from synapse.api import errors
-
-from tests import unittest, utils
+from tests import unittest
 
 # sample room_key data for use in the tests
 room_keys = {
@@ -45,51 +40,39 @@ room_keys = {
 }
 
 
-class E2eRoomKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(replication_layer=mock.Mock())
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, replication_layer=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
-        self.local_user = "@boris:" + self.hs.hostname
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_room_keys_handler()
+        self.local_user = "@boris:" + hs.hostname
 
-    @defer.inlineCallbacks
     def test_get_missing_current_version_info(self):
         """Check that we get a 404 if we ask for info about the current version
         if there is no version.
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_version_info(self):
         """Check that we get a 404 if we ask for info about a specific version
         if it doesn't exist.
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "bogus_version"),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_create_version(self):
         """Check that we can create and then retrieve versions.
         """
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -101,7 +84,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         version_etag = res["etag"]
         self.assertIsInstance(version_etag, str)
         del res["etag"]
@@ -116,9 +99,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as a specific version
-        res = yield defer.ensureDeferred(
-            self.handler.get_version_info(self.local_user, "1")
-        )
+        res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
         self.assertEqual(res["etag"], version_etag)
         del res["etag"]
         self.assertDictEqual(
@@ -132,7 +113,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # upload a new one...
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -144,7 +125,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "2")
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -156,11 +137,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_version(self):
         """Check that we can update versions.
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -171,7 +151,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.update_version(
                 self.local_user,
                 version,
@@ -185,7 +165,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {})
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         self.assertDictEqual(
             res,
@@ -197,32 +177,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_missing_version(self):
         """Check that we get a 404 on updating nonexistent versions
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    "1",
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "1",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                "1",
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "1",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_update_omitted_version(self):
         """Check that the update succeeds if the version is missing from the body
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -233,7 +209,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.update_version(
                 self.local_user,
                 version,
@@ -245,7 +221,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
 
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]  # etag is opaque, so don't test its contents
         self.assertDictEqual(
             res,
@@ -257,11 +233,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
         )
 
-    @defer.inlineCallbacks
     def test_update_bad_version(self):
         """Check that we get a 400 if the version in the body doesn't match
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -272,52 +247,41 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    version,
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "incorrect",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "incorrect",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 400)
 
-    @defer.inlineCallbacks
     def test_delete_missing_version(self):
         """Check that we get a 404 on deleting nonexistent versions
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.delete_version(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.delete_version(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_delete_missing_current_version(self):
         """Check that we get a 404 on deleting nonexistent current version
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_delete_version(self):
         """Check that we can create and then delete versions.
         """
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -329,36 +293,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
 
         # check we can delete it
-        yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
+        self.get_success(self.handler.delete_version(self.local_user, "1"))
 
         # check that it's gone
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_backup(self):
         """Check that we get a 404 on querying missing backup
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_room_keys(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_get_missing_room_keys(self):
         """Check we get an empty response from an empty backup
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -369,33 +325,27 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, {"rooms": {}})
 
     # TODO: test the locking semantics when uploading room_keys,
     # although this is probably best done in sytest
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_no_versions(self):
         """Check that we get a 404 on uploading keys when no versions are defined
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_bogus_version(self):
         """Check that we get a 404 on uploading keys when an nonexistent version
         is specified
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -406,22 +356,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(
-                    self.local_user, "bogus_version", room_keys
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_wrong_version(self):
         """Check that we get a 403 on uploading keys for an old version
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -432,7 +377,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -443,20 +388,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "2")
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "1", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 403)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_insert(self):
         """Check that we can insert and retrieve keys for a session
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -467,17 +408,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given room
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
             )
@@ -485,18 +424,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, room_keys)
 
         # check getting room_keys for a given session_id
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
         )
         self.assertDictEqual(res, room_keys)
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_merge(self):
         """Check that we can upload a new room_key for an existing session and
         have it correctly merged"""
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -507,12 +445,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         self.assertEqual(version, "1")
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
 
         # get the etag to compare to future versions
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         backup_etag = res["etag"]
         self.assertEqual(res["count"], 1)
 
@@ -522,37 +460,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # test that increasing the message_index doesn't replace the existing session
         new_room_key["first_message_index"] = 2
         new_room_key["session_data"] = "new"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
         )
 
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # test that marking the session as verified however /does/ replace it
         new_room_key["is_verified"] = True
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should NOT be equal now, since the key changed
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertNotEqual(res["etag"], backup_etag)
         backup_etag = res["etag"]
 
@@ -560,28 +494,25 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # with a lower forwarding count
         new_room_key["forwarded_count"] = 2
         new_room_key["session_data"] = "other"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
 
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
 
         # TODO: check edge cases as well as the common variations here
 
-    @defer.inlineCallbacks
     def test_delete_room_keys(self):
         """Check that we can insert and delete keys for a session
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
                 self.local_user,
                 {
@@ -593,13 +524,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(version, "1")
 
         # check for bulk-delete
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
-            self.handler.delete_room_keys(self.local_user, version)
-        )
-        res = yield defer.ensureDeferred(
+        self.get_success(self.handler.delete_room_keys(self.local_user, version))
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
@@ -607,15 +536,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per room
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
             )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
@@ -623,15 +552,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
 
         # check for bulk-delete per session
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 022943a10a..787fab7875 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -13,25 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from mock import Mock
 
-from twisted.internet import defer
-
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
-class ProfileTestCase(unittest.TestCase):
+class ProfileTestCase(unittest.HomeserverTestCase):
     """ Tests profile management. """
 
-    @defer.inlineCallbacks
-    def setUp(self):
+    def make_homeserver(self, reactor, clock):
         self.mock_federation = Mock()
         self.mock_registry = Mock()
 
@@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase):
 
         self.mock_registry.register_query_handler = register_query_handler
 
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
+        hs = self.setup_test_homeserver(
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
         )
+        return hs
 
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
         self.frank = UserID.from_string("@1234ABCD:test")
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
 
-        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+        self.get_success(self.store.create_profile(self.frank.localpart))
 
         self.handler = hs.get_profile_handler()
-        self.hs = hs
 
-    @defer.inlineCallbacks
     def test_get_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.frank)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.frank))
 
         self.assertEquals("Frank", displayname)
 
-    @defer.inlineCallbacks
     def test_set_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
             )
@@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set displayname again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank"
             )
@@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Set displayname to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), ""
             )
         )
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            )
+            (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
         )
 
-    @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
 
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
 
         self.assertEquals(
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
             ),
@@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         # Setting displayname a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
-            )
+            ),
+            SynapseError,
         )
 
-        yield self.assertFailure(d, SynapseError)
-
-    @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
-            )
+            ),
+            AuthError,
         )
 
-        yield self.assertFailure(d, AuthError)
-
-    @defer.inlineCallbacks
     def test_get_other_name(self):
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
         )
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.alice)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.alice))
 
         self.assertEquals(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
@@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase):
             ignore_backoff=True,
         )
 
-    @defer.inlineCallbacks
     def test_incoming_fed_query(self):
-        yield defer.ensureDeferred(self.store.create_profile("caroline"))
-        yield defer.ensureDeferred(
-            self.store.set_profile_displayname("caroline", "Caroline")
-        )
+        self.get_success(self.store.create_profile("caroline"))
+        self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
-        response = yield defer.ensureDeferred(
+        response = self.get_success(
             self.query_handlers["profile"](
                 {"user_id": "@caroline:test", "field": "displayname"}
             )
@@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase):
 
         self.assertEquals({"displayname": "Caroline"}, response)
 
-    @defer.inlineCallbacks
     def test_get_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
             )
         )
-        avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
+        avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
 
         self.assertEquals("http://my.server/me.png", avatar_url)
 
-    @defer.inlineCallbacks
     def test_set_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
@@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/pic.gif",
         )
 
         # Set avatar again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
@@ -230,56 +201,42 @@ class ProfileTestCase(unittest.TestCase):
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
 
         # Set avatar to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
                 self.frank, synapse.types.create_requester(self.frank), "",
             )
         )
 
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
 
-    @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
 
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
             )
         )
 
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
         )
 
         # Set avatar a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_avatar_url(
                 self.frank,
                 synapse.types.create_requester(self.frank),
                 "http://my.server/pic.gif",
-            )
+            ),
+            SynapseError,
         )
-
-        yield self.assertFailure(d, SynapseError)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index a8d6c0f617..029af2853e 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -259,7 +259,61 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
 
+    @override_config(
+        {
+            "saml2_config": {
+                "attribute_requirements": [
+                    {"attribute": "userGroup", "value": "staff"},
+                    {"attribute": "department", "value": "sales"},
+                ],
+            },
+        }
+    )
+    def test_attribute_requirements(self):
+        """The required attributes must be met from the SAML response."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # The response doesn't have the proper userGroup or department.
+        saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # The response doesn't have the proper department.
+        saml_response = FakeAuthnResponse(
+            {"uid": "test_user", "username": "test_user", "userGroup": ["staff"]}
+        )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # Add the proper attributes and it should succeed.
+        saml_response = FakeAuthnResponse(
+            {
+                "uid": "test_user",
+                "username": "test_user",
+                "userGroup": ["staff", "admin"],
+                "department": ["sales"],
+            }
+        )
+        request.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
 
 def _mock_request():
     """Returns a mock which will stand in as a SynapseRequest"""
-    return Mock(spec=["getClientIP", "getHeader"])
+    return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index c4e1e7ed85..22f452ec24 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -124,13 +124,18 @@ class EmailPusherTests(HomeserverTestCase):
         )
         self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
 
-        # The other user sends some messages
+        # The other user sends a single message.
         self.helper.send(room, body="Hi!", tok=self.others[0].token)
-        self.helper.send(room, body="There!", tok=self.others[0].token)
 
         # We should get emailed about that message
         self._check_for_mail()
 
+        # The other user sends multiple messages.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+        self.helper.send(room, body="There!", tok=self.others[0].token)
+
+        self._check_for_mail()
+
     def test_invite_sends_email(self):
         # Create a room and invite the user to it
         room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
@@ -217,6 +222,45 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
+    def test_empty_room(self):
+        """All users leaving a room shouldn't cause the pusher to break."""
+        # Create a simple room with two users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+        # The other user sends a single message.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+        # Leave the room before the message is processed.
+        self.helper.leave(room, self.user_id, tok=self.access_token)
+        self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+        # We should get emailed about that message
+        self._check_for_mail()
+
+    def test_empty_room_multiple_messages(self):
+        """All users leaving a room shouldn't cause the pusher to break."""
+        # Create a simple room with two users
+        room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+        self.helper.invite(
+            room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+        )
+        self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+        # The other user sends a single message.
+        self.helper.send(room, body="Hi!", tok=self.others[0].token)
+        self.helper.send(room, body="There!", tok=self.others[0].token)
+
+        # Leave the room before the message is processed.
+        self.helper.leave(room, self.user_id, tok=self.access_token)
+        self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+        # We should get emailed about that message
+        self._check_for_mail()
+
     def test_encrypted_message(self):
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
@@ -269,3 +313,6 @@ class EmailPusherTests(HomeserverTestCase):
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
         self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+
+        # Reset the attempts.
+        self.email_attempts = []
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 7c47aa7e0a..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1445,6 +1445,90 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
 
+    def test_context_as_non_admin(self):
+        """
+        Test that, without being admin, one cannot use the context admin API
+        """
+        # Create a room.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+
+        self.register_user("test_2", "test")
+        user_tok_2 = self.login("test_2", "test")
+
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        # Populate the room with events.
+        events = []
+        for i in range(30):
+            events.append(
+                self.helper.send_event(
+                    room_id, "com.example.test", content={"index": i}, tok=user_tok
+                )
+            )
+
+        # Now attempt to find the context using the admin API without being admin.
+        midway = (len(events) - 1) // 2
+        for tok in [user_tok, user_tok_2]:
+            channel = self.make_request(
+                "GET",
+                "/_synapse/admin/v1/rooms/%s/context/%s"
+                % (room_id, events[midway]["event_id"]),
+                access_token=tok,
+            )
+            self.assertEquals(
+                403, int(channel.result["code"]), msg=channel.result["body"]
+            )
+            self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_context_as_admin(self):
+        """
+        Test that, as admin, we can find the context of an event without having joined the room.
+        """
+
+        # Create a room. We're not part of it.
+        user_id = self.register_user("test", "test")
+        user_tok = self.login("test", "test")
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        # Populate the room with events.
+        events = []
+        for i in range(30):
+            events.append(
+                self.helper.send_event(
+                    room_id, "com.example.test", content={"index": i}, tok=user_tok
+                )
+            )
+
+        # Now let's fetch the context for this room.
+        midway = (len(events) - 1) // 2
+        channel = self.make_request(
+            "GET",
+            "/_synapse/admin/v1/rooms/%s/context/%s"
+            % (room_id, events[midway]["event_id"]),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEquals(
+            channel.json_body["event"]["event_id"], events[midway]["event_id"]
+        )
+
+        for i, found_event in enumerate(channel.json_body["events_before"]):
+            for j, posted_event in enumerate(events):
+                if found_event["event_id"] == posted_event["event_id"]:
+                    self.assertTrue(j < midway)
+                    break
+            else:
+                self.fail("Event %s from events_before not found" % j)
+
+        for i, found_event in enumerate(channel.json_body["events_after"]):
+            for j, posted_event in enumerate(events):
+                if found_event["event_id"] == posted_event["event_id"]:
+                    self.assertTrue(j > midway)
+                    break
+            else:
+                self.fail("Event %s from events_after not found" % j)
+
 
 class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index bfcb786af8..49543d9acb 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
 
 import time
 import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
 from urllib.parse import urlencode
 
 from mock import Mock
@@ -493,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
+        html = channel.result["body"].decode("utf-8")
         p = TestHtmlParser()
-        p.feed(channel.result["body"].decode("utf-8"))
+        p.feed(html)
         p.close()
 
-        self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+        # there should be a link for each href
+        returned_idps = []  # type: List[str]
+        for link in p.links:
+            path, query = link.split("?", 1)
+            self.assertEqual(path, "pick_idp")
+            params = urllib.parse.parse_qs(query)
+            self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+            returned_idps.append(params["idp"][0])
 
-        self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+        self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
 
     def test_multi_sso_redirect_to_cas(self):
         """If CAS is chosen, should redirect to the CAS server"""
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 38c51525a3..f6f3b9a356 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -18,8 +18,6 @@
 
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.rest.client.v1 import room
 from synapse.types import UserID
 
@@ -60,32 +58,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         hs.get_datastore().insert_client_ip = _insert_client_ip
 
-        def get_room_members(room_id):
-            if room_id == self.room_id:
-                return defer.succeed([self.user])
-            else:
-                return defer.succeed([])
-
-        @defer.inlineCallbacks
-        def fetch_room_distributions_into(
-            room_id, localusers=None, remotedomains=None, ignore_user=None
-        ):
-            members = yield get_room_members(room_id)
-            for member in members:
-                if ignore_user is not None and member == ignore_user:
-                    continue
-
-                if hs.is_mine(member):
-                    if localusers is not None:
-                        localusers.add(member)
-                else:
-                    if remotedomains is not None:
-                        remotedomains.add(member.domain)
-
-        hs.get_room_member_handler().fetch_room_distributions_into = (
-            fetch_room_distributions_into
-        )
-
         return hs
 
     def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
 
 from tests import unittest
 from tests.server import FakeSite, make_request
+from tests.utils import default_config
 
 
 class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
             headers.getRawHeaders(b"X-Robots-Tag"),
             [b"noindex, nofollow, noarchive, noimageindex"],
         )
+
+
+class TestSpamChecker:
+    """A spam checker module that rejects all media that includes the bytes
+    `evil`.
+    """
+
+    def __init__(self, config, api):
+        self.config = config
+        self.api = api
+
+    def parse_config(config):
+        return config
+
+    async def check_event_for_spam(self, foo):
+        return False  # allow all events
+
+    async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+        return True  # allow all invites
+
+    async def user_may_create_room(self, userid):
+        return True  # allow all room creations
+
+    async def user_may_create_room_alias(self, userid, room_alias):
+        return True  # allow all room aliases
+
+    async def user_may_publish_room(self, userid, room_id):
+        return True  # allow publishing of all rooms
+
+    async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+        buf = BytesIO()
+        await file_wrapper.write_chunks_to(buf.write)
+
+        return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.user = self.register_user("user", "pass")
+        self.tok = self.login("user", "pass")
+
+        # Allow for uploading and downloading to/from the media repo
+        self.media_repo = hs.get_media_repository_resource()
+        self.download_resource = self.media_repo.children[b"download"]
+        self.upload_resource = self.media_repo.children[b"upload"]
+
+    def default_config(self):
+        config = default_config("test")
+
+        config.update(
+            {
+                "spam_checker": [
+                    {
+                        "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+                        "config": {},
+                    }
+                ]
+            }
+        )
+
+        return config
+
+    def test_upload_innocent(self):
+        """Attempt to upload some innocent data that should be allowed.
+        """
+
+        image_data = unhexlify(
+            b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+            b"0000001f15c4890000000a49444154789c63000100000500010d"
+            b"0a2db40000000049454e44ae426082"
+        )
+
+        self.helper.upload_media(
+            self.upload_resource, image_data, tok=self.tok, expect_code=200
+        )
+
+    def test_upload_ban(self):
+        """Attempt to upload some data that includes bytes "evil", which should
+        get rejected by the spam checker.
+        """
+
+        data = b"Some evil data"
+
+        self.helper.upload_media(
+            self.upload_resource, data, tok=self.tok, expect_code=400
+        )
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 0c6cbbd921..ea83299918 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,6 +15,7 @@
 
 from synapse.rest.media.v1.preview_url_resource import (
     decode_and_calc_og,
+    get_html_media_encoding,
     summarize_paragraphs,
 )
 
@@ -26,7 +27,7 @@ except ImportError:
     lxml = None
 
 
-class PreviewTestCase(unittest.TestCase):
+class SummarizeTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
@@ -144,12 +145,12 @@ class PreviewTestCase(unittest.TestCase):
         )
 
 
-class PreviewUrlTestCase(unittest.TestCase):
+class CalcOgTestCase(unittest.TestCase):
     if not lxml:
         skip = "url preview feature requires lxml"
 
     def test_simple(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -163,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -178,7 +179,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_comment2(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -202,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         )
 
     def test_script(self):
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -217,7 +218,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
     def test_missing_title(self):
-        html = """
+        html = b"""
         <html>
         <body>
         Some text.
@@ -230,7 +231,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_h1_as_title(self):
-        html = """
+        html = b"""
         <html>
         <meta property="og:description" content="Some text."/>
         <body>
@@ -244,7 +245,7 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
 
     def test_missing_title_and_broken_h1(self):
-        html = """
+        html = b"""
         <html>
         <body>
         <h1><a href="foo"/></h1>
@@ -258,13 +259,20 @@ class PreviewUrlTestCase(unittest.TestCase):
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
 
     def test_empty(self):
-        html = ""
+        """Test a body with no data in it."""
+        html = b""
+        og = decode_and_calc_og(html, "http://example.com/test.html")
+        self.assertEqual(og, {})
+
+    def test_no_tree(self):
+        """A valid body with no tree in it."""
+        html = b"\x00"
         og = decode_and_calc_og(html, "http://example.com/test.html")
         self.assertEqual(og, {})
 
     def test_invalid_encoding(self):
         """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
-        html = """
+        html = b"""
         <html>
         <head><title>Foo</title></head>
         <body>
@@ -290,3 +298,76 @@ class PreviewUrlTestCase(unittest.TestCase):
         """
         og = decode_and_calc_og(html, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
+
+
+class MediaEncodingTestCase(unittest.TestCase):
+    def test_meta_charset(self):
+        """A character encoding is found via the meta tag."""
+        encoding = get_html_media_encoding(
+            b"""
+        <html>
+        <head><meta charset="ascii">
+        </head>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "ascii")
+
+        # A less well-formed version.
+        encoding = get_html_media_encoding(
+            b"""
+        <html>
+        <head>< meta charset = ascii>
+        </head>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "ascii")
+
+    def test_xml_encoding(self):
+        """A character encoding is found via the meta tag."""
+        encoding = get_html_media_encoding(
+            b"""
+        <?xml version="1.0" encoding="ascii"?>
+        <html>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "ascii")
+
+    def test_meta_xml_encoding(self):
+        """Meta tags take precedence over XML encoding."""
+        encoding = get_html_media_encoding(
+            b"""
+        <?xml version="1.0" encoding="ascii"?>
+        <html>
+        <head><meta charset="UTF-16">
+        </head>
+        </html>
+        """,
+            "text/html",
+        )
+        self.assertEqual(encoding, "UTF-16")
+
+    def test_content_type(self):
+        """A character encoding is found via the Content-Type header."""
+        # Test a few variations of the header.
+        headers = (
+            'text/html; charset="ascii";',
+            "text/html;charset=ascii;",
+            'text/html;  charset="ascii"',
+            "text/html; charset=ascii",
+            'text/html; charset="ascii;',
+            'text/html; charset=ascii";',
+        )
+        for header in headers:
+            encoding = get_html_media_encoding(b"", header)
+            self.assertEqual(encoding, "ascii")
+
+    def test_fallback(self):
+        """A character encoding cannot be found in the body or header."""
+        encoding = get_html_media_encoding(b"", "text/html")
+        self.assertEqual(encoding, "utf-8")