summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_federation.py64
-rw-r--r--tests/handlers/test_password_providers.py223
-rw-r--r--tests/handlers/test_user_directory.py192
-rw-r--r--tests/rest/admin/test_user.py401
-rw-r--r--tests/rest/client/test_relations.py55
-rw-r--r--tests/rest/client/test_third_party_rules.py56
-rw-r--r--tests/rest/media/v1/test_filepath.py238
-rw-r--r--tests/rest/media/v1/test_oembed.py51
-rw-r--r--tests/server.py54
-rw-r--r--tests/storage/test_user_directory.py77
-rw-r--r--tests/test_event_auth.py138
-rw-r--r--tests/test_preview.py93
-rw-r--r--tests/unittest.py4
13 files changed, 1155 insertions, 491 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 936ebf3dde..e1557566e4 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -23,6 +23,7 @@ from synapse.federation.federation_base import event_from_pdu_json
 from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.types import create_requester
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -30,6 +31,10 @@ from tests import unittest
 logger = logging.getLogger(__name__)
 
 
+def generate_fake_event_id() -> str:
+    return "$fake_" + random_string(43)
+
+
 class FederationTestCase(unittest.HomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -198,6 +203,65 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
+    def test_backfill_with_many_backward_extremities(self):
+        """
+        Check that we can backfill with many backward extremities.
+        The goal is to make sure that when we only use a portion
+        of backwards extremities(the magic number is more than 5),
+        no errors are thrown.
+
+        Regression test, see #11027
+        """
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        requester = create_requester(user_id)
+
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        ev1 = self.helper.send(room_id, "first message", tok=tok)
+
+        # Create "many" backward extremities. The magic number we're trying to
+        # create more than is 5 which corresponds to the number of backward
+        # extremities we slice off in `_maybe_backfill_inner`
+        for _ in range(0, 8):
+            event_handler = self.hs.get_event_creation_handler()
+            event, context = self.get_success(
+                event_handler.create_event(
+                    requester,
+                    {
+                        "type": "m.room.message",
+                        "content": {
+                            "msgtype": "m.text",
+                            "body": "message connected to fake event",
+                        },
+                        "room_id": room_id,
+                        "sender": user_id,
+                    },
+                    prev_event_ids=[
+                        ev1["event_id"],
+                        # We're creating an backward extremity each time thanks
+                        # to this fake event
+                        generate_fake_event_id(),
+                    ],
+                )
+            )
+            self.get_success(
+                event_handler.handle_new_client_event(requester, event, context)
+            )
+
+        current_depth = 1
+        limit = 100
+        with LoggingContext("receive_pdu"):
+            # Make sure backfill still works
+            d = run_in_background(
+                self.hs.get_federation_handler().maybe_backfill,
+                room_id,
+                current_depth,
+                limit,
+            )
+        self.get_success(d)
+
     def test_backfill_floating_outlier_membership_auth(self):
         """
         As the local homeserver, check that we can properly process a federated
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 38e6d9f536..7dd4a5a367 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,6 +20,8 @@ from unittest.mock import Mock
 from twisted.internet import defer
 
 import synapse
+from synapse.handlers.auth import load_legacy_password_auth_providers
+from synapse.module_api import ModuleApi
 from synapse.rest.client import devices, login
 from synapse.types import JsonDict
 
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
 mock_password_provider = Mock()
 
 
-class PasswordOnlyAuthProvider:
-    """A password_provider which only implements `check_password`."""
+class LegacyPasswordOnlyAuthProvider:
+    """A legacy password_provider which only implements `check_password`."""
 
     @staticmethod
     def parse_config(self):
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
         return mock_password_provider.check_password(*args)
 
 
-class CustomAuthProvider:
-    """A password_provider which implements a custom login type."""
+class LegacyCustomAuthProvider:
+    """A legacy password_provider which implements a custom login type."""
 
     @staticmethod
     def parse_config(self):
@@ -67,7 +69,23 @@ class CustomAuthProvider:
         return mock_password_provider.check_auth(*args)
 
 
-class PasswordCustomAuthProvider:
+class CustomAuthProvider:
+    """A module which registers password_auth_provider callbacks for a custom login type."""
+
+    @staticmethod
+    def parse_config(self):
+        pass
+
+    def __init__(self, config, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+        )
+
+    def check_auth(self, *args):
+        return mock_password_provider.check_auth(*args)
+
+
+class LegacyPasswordCustomAuthProvider:
     """A password_provider which implements password login via `check_auth`, as well
     as a custom type."""
 
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
         return mock_password_provider.check_auth(*args)
 
 
-def providers_config(*providers: Type[Any]) -> dict:
-    """Returns a config dict that will enable the given password auth providers"""
+class PasswordCustomAuthProvider:
+    """A module which registers password_auth_provider callbacks for a custom login type.
+    as well as a password login"""
+
+    @staticmethod
+    def parse_config(self):
+        pass
+
+    def __init__(self, config, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("test.login_type", ("test_field",)): self.check_auth,
+                ("m.login.password", ("password",)): self.check_auth,
+            },
+        )
+        pass
+
+    def check_auth(self, *args):
+        return mock_password_provider.check_auth(*args)
+
+    def check_pass(self, *args):
+        return mock_password_provider.check_password(*args)
+
+
+def legacy_providers_config(*providers: Type[Any]) -> dict:
+    """Returns a config dict that will enable the given legacy password auth providers"""
     return {
         "password_providers": [
             {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
     }
 
 
+def providers_config(*providers: Type[Any]) -> dict:
+    """Returns a config dict that will enable the given modules"""
+    return {
+        "modules": [
+            {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+            for provider in providers
+        ]
+    }
+
+
 class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
         super().setUp()
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_login(self):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        # Load the modules into the homeserver
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+        load_legacy_password_auth_providers(hs)
+
+        return hs
+
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_password_only_auth_progiver_login_legacy(self):
+        self.password_only_auth_provider_login_test_body()
+
+    def password_only_auth_provider_login_test_body(self):
         # login flows should only have m.login.password
         flows = self._get_login_flows()
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "@ USER🙂NAME :test", " pASS😢word "
         )
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_ui_auth(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_password_only_auth_provider_ui_auth_legacy(self):
+        self.password_only_auth_provider_ui_auth_test_body()
+
+    def password_only_auth_provider_ui_auth_test_body(self):
         """UI Auth should delegate correctly to the password provider"""
 
         # create the user, otherwise access doesn't work
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_local_user_fallback_login(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_local_user_fallback_login_legacy(self):
+        self.local_user_fallback_login_test_body()
+
+    def local_user_fallback_login_test_body(self):
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
 
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@localuser:test", channel.json_body["user_id"])
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_local_user_fallback_ui_auth(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_local_user_fallback_ui_auth_legacy(self):
+        self.local_user_fallback_ui_auth_test_body()
+
+    def local_user_fallback_ui_auth_test_body(self):
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
 
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_no_local_user_fallback_login(self):
+    def test_no_local_user_fallback_login_legacy(self):
+        self.no_local_user_fallback_login_test_body()
+
+    def no_local_user_fallback_login_test_body(self):
         """localdb_enabled can block login with the local password"""
         self.register_user("localuser", "localpass")
 
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_no_local_user_fallback_ui_auth(self):
+    def test_no_local_user_fallback_ui_auth_legacy(self):
+        self.no_local_user_fallback_ui_auth_test_body()
+
+    def no_local_user_fallback_ui_auth_test_body(self):
         """localdb_enabled can block ui auth with the local password"""
         self.register_user("localuser", "localpass")
 
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"enabled": False},
         }
     )
-    def test_password_auth_disabled(self):
+    def test_password_auth_disabled_legacy(self):
+        self.password_auth_disabled_test_body()
+
+    def password_auth_disabled_test_body(self):
         """password auth doesn't work if it's disabled across the board"""
         # login flows should be empty
         flows = self._get_login_flows()
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_password.assert_not_called()
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_login_legacy(self):
+        self.custom_auth_provider_login_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_login(self):
+        self.custom_auth_provider_login_test_body()
+
+    def custom_auth_provider_login_test_body(self):
         # login flows should have the custom flow and m.login.password, since we
         # haven't disabled local password lookup.
         # (password must come first, because reasons)
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
-        mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+        mock_password_provider.check_auth.return_value = defer.succeed(
+            ("@user:bz", None)
+        )
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@user:bz", channel.json_body["user_id"])
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # in these cases, but at least we can guard against the API changing
         # unexpectedly
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@ MALFORMED! :bz"
+            ("@ MALFORMED! :bz", None)
         )
         channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
         self.assertEqual(channel.code, 200, channel.result)
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
         )
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_ui_auth_legacy(self):
+        self.custom_auth_provider_ui_auth_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_ui_auth(self):
+        self.custom_auth_provider_ui_auth_test_body()
+
+    def custom_auth_provider_ui_auth_test_body(self):
         # register the user and log in twice, to get two devices
         self.register_user("localuser", "localpass")
         tok1 = self.login("localuser", "localpass")
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+        mock_password_provider.check_auth.return_value = defer.succeed(
+            ("@user:bz", None)
+        )
         body["auth"]["test_field"] = "foo"
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 403)
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         # and finally, succeed
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@localuser:test"
+            ("@localuser:test", None)
         )
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 200)
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "localuser", "test.login_type", {"test_field": "foo"}
         )
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_callback_legacy(self):
+        self.custom_auth_provider_callback_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_callback(self):
+        self.custom_auth_provider_callback_test_body()
+
+    def custom_auth_provider_callback_test_body(self):
         callback = Mock(return_value=defer.succeed(None))
 
         mock_password_provider.check_auth.return_value = defer.succeed(
@@ -411,9 +519,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             self.assertIn(p, call_args[0])
 
     @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_legacy(self):
+        self.custom_auth_password_disabled_test_body()
+
+    @override_config(
         {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
     )
     def test_custom_auth_password_disabled(self):
+        self.custom_auth_password_disabled_test_body()
+
+    def custom_auth_password_disabled_test_body(self):
         """Test login with a custom auth provider where password login is disabled"""
         self.register_user("localuser", "localpass")
 
@@ -427,11 +547,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"enabled": False, "localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+        self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+    @override_config(
+        {
             **providers_config(CustomAuthProvider),
             "password_config": {"enabled": False, "localdb_enabled": False},
         }
     )
     def test_custom_auth_password_disabled_localdb_enabled(self):
+        self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+    def custom_auth_password_disabled_localdb_enabled_test_body(self):
         """Check the localdb_enabled == enabled == False
 
         Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -450,11 +582,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
+            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_password_custom_auth_password_disabled_login_legacy(self):
+        self.password_custom_auth_password_disabled_login_test_body()
+
+    @override_config(
+        {
             **providers_config(PasswordCustomAuthProvider),
             "password_config": {"enabled": False},
         }
     )
     def test_password_custom_auth_password_disabled_login(self):
+        self.password_custom_auth_password_disabled_login_test_body()
+
+    def password_custom_auth_password_disabled_login_test_body(self):
         """log in with a custom auth provider which implements password, but password
         login is disabled"""
         self.register_user("localuser", "localpass")
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
+        mock_password_provider.check_password.assert_not_called()
+
+    @override_config(
+        {
+            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+        self.password_custom_auth_password_disabled_ui_auth_test_body()
 
     @override_config(
         {
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_password_custom_auth_password_disabled_ui_auth(self):
+        self.password_custom_auth_password_disabled_ui_auth_test_body()
+
+    def password_custom_auth_password_disabled_ui_auth_test_body(self):
         """UI Auth with a custom auth provider which implements password, but password
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@localuser:test"
+            ("@localuser:test", None)
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
         self.assertEqual(channel.code, 200, channel.result)
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "Password login has been disabled.", channel.json_body["error"]
         )
         mock_password_provider.check_auth.assert_not_called()
+        mock_password_provider.check_password.assert_not_called()
         mock_password_provider.reset_mock()
 
         # successful auth
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_auth.assert_called_once_with(
             "localuser", "test.login_type", {"test_field": "x"}
         )
+        mock_password_provider.check_password.assert_not_called()
+
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_no_local_user_fallback_legacy(self):
+        self.custom_auth_no_local_user_fallback_test_body()
 
     @override_config(
         {
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_custom_auth_no_local_user_fallback(self):
+        self.custom_auth_no_local_user_fallback_test_body()
+
+    def custom_auth_no_local_user_fallback_test_body(self):
         """Test login with a custom auth provider where the local db is disabled"""
         self.register_user("localuser", "localpass")
 
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 0120b4688b..b9ad92b977 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             tok=alice_token,
         )
 
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
-        in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
-        in_private = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
+        # The user directory should reflect the room memberships above.
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
         )
-
         self.assertEqual(users, {alice, bob})
+        self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)})
         self.assertEqual(
-            set(in_public), {(alice, public), (bob, public), (alice, public2)}
-        )
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(in_private),
+            in_private,
             {(alice, bob, private), (bob, alice, private)},
         )
 
@@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
         self.assertEqual(set(in_public), {(user1, room), (user2, room)})
 
+    def test_excludes_users_when_making_room_public(self) -> None:
+        # Create a regular user and a support user.
+        alice = self.register_user("alice", "pass")
+        alice_token = self.login(alice, "pass")
+        support = "@support1:test"
+        self.get_success(
+            self.store.register_user(
+                user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+            )
+        )
+
+        # Make a public and private room containing Alice and the support user
+        public, initially_private = self._create_rooms_and_inject_memberships(
+            alice, alice_token, support
+        )
+        self._check_only_one_user_in_directory(alice, public)
+
+        # Alice makes the private room public.
+        self.helper.send_state(
+            initially_private,
+            "m.room.join_rules",
+            {"join_rule": "public"},
+            tok=alice_token,
+        )
+
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice})
+        self.assertEqual(in_public, {(alice, public), (alice, initially_private)})
+        self.assertEqual(in_private, set())
+
+    def test_switching_from_private_to_public_to_private(self) -> None:
+        """Check we update the room sharing tables when switching a room
+        from private to public, then back again to private."""
+        # Alice and Bob share a private room.
+        alice = self.register_user("alice", "pass")
+        alice_token = self.login(alice, "pass")
+        bob = self.register_user("bob", "pass")
+        bob_token = self.login(bob, "pass")
+        room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+        self.helper.invite(room, alice, bob, tok=alice_token)
+        self.helper.join(room, bob, tok=bob_token)
+
+        # The user directory should reflect this.
+        def check_user_dir_for_private_room() -> None:
+            users, in_public, in_private = self.get_success(
+                self.user_dir_helper.get_tables()
+            )
+            self.assertEqual(users, {alice, bob})
+            self.assertEqual(in_public, set())
+            self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
+
+        check_user_dir_for_private_room()
+
+        # Alice makes the room public.
+        self.helper.send_state(
+            room,
+            "m.room.join_rules",
+            {"join_rule": "public"},
+            tok=alice_token,
+        )
+
+        # The user directory should be updated accordingly
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice, bob})
+        self.assertEqual(in_public, {(alice, room), (bob, room)})
+        self.assertEqual(in_private, set())
+
+        # Alice makes the room private.
+        self.helper.send_state(
+            room,
+            "m.room.join_rules",
+            {"join_rule": "invite"},
+            tok=alice_token,
+        )
+
+        # The user directory should be updated accordingly
+        check_user_dir_for_private_room()
+
     def _create_rooms_and_inject_memberships(
         self, creator: str, token: str, joiner: str
     ) -> Tuple[str, str]:
@@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         return public_room, private_room
 
     def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
-        in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
-        in_private = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
-        )
+        """Check that the user directory DB tables show that:
 
+        - only one user is in the user directory
+        - they belong to exactly one public room
+        - they don't share a private room with anyone.
+        """
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
         self.assertEqual(users, {user})
-        self.assertEqual(set(in_public), {(user, public)})
-        self.assertEqual(in_private, [])
+        self.assertEqual(in_public, {(user, public)})
+        self.assertEqual(in_private, set())
 
     def test_handle_local_profile_change_with_support_user(self) -> None:
         support_user_id = "@support:test"
@@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u2, room), (u2, u1, room)},
-        )
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+        self.assertEqual(public_users, set())
 
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, set())
+        self.assertEqual(public_users, set())
 
         # User1 now gets no search results for any of the other users.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u2, room), (u2, u1, room)},
-        )
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+        self.assertEqual(public_users, set())
 
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u2, room), (u2, u1, room)},
-        )
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+        self.assertEqual(public_users, set())
 
         # Configure a spam checker.
         spam_checker = self.hs.get_spam_checker()
@@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         # No users share rooms
-        self.assertEqual(public_users, [])
-        self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
+        self.assertEqual(public_users, set())
+        self.assertEqual(shares_private, set())
 
         # Despite not sharing a room, search_all_users means we get a search
         # result.
@@ -842,6 +914,56 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.hs.get_storage().persistence.persist_event(event, context)
         )
 
+    def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
+        """We've chosen to simplify the user directory's implementation by
+        always including local users. Ensure this invariant is maintained when
+        a local user
+        - leaves a room, and
+        - leaves the last room they're in which is visible to this server.
+
+        This is user-visible if the "search_all_users" config option is on: the
+        local user who left a room would no longer be searchable if this test fails!
+        """
+        alice = self.register_user("alice", "pass")
+        alice_token = self.login(alice, "pass")
+        bob = self.register_user("bob", "pass")
+        bob_token = self.login(bob, "pass")
+
+        # Alice makes two public rooms, which Bob joins.
+        room1 = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+        room2 = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+        self.helper.join(room1, bob, tok=bob_token)
+        self.helper.join(room2, bob, tok=bob_token)
+
+        # The user directory tables are updated.
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice, bob})
+        self.assertEqual(
+            in_public, {(alice, room1), (alice, room2), (bob, room1), (bob, room2)}
+        )
+        self.assertEqual(in_private, set())
+
+        # Alice leaves one room. She should still be in the directory.
+        self.helper.leave(room1, alice, tok=alice_token)
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice, bob})
+        self.assertEqual(in_public, {(alice, room2), (bob, room1), (bob, room2)})
+        self.assertEqual(in_private, set())
+
+        # Alice leaves the other. She should still be in the directory.
+        self.helper.leave(room2, alice, tok=alice_token)
+        self.wait_for_background_updates()
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice, bob})
+        self.assertEqual(in_public, {(bob, room1), (bob, room2)})
+        self.assertEqual(in_private, set())
+
 
 class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
     servlets = [
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6ed9e42173..c9e2754b09 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -14,14 +14,13 @@
 
 import hashlib
 import hmac
-import json
 import os
 import urllib.parse
 from binascii import unhexlify
 from typing import List, Optional
 from unittest.mock import Mock, patch
 
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
@@ -104,8 +103,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         # 59 seconds
         self.reactor.advance(59)
 
-        body = json.dumps({"nonce": nonce})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("username must be specified", channel.json_body["error"])
@@ -113,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         # 61 seconds
         self.reactor.advance(2)
 
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -129,18 +128,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob",
-                "password": "abc123",
-                "admin": True,
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob",
+            "password": "abc123",
+            "admin": True,
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual("HMAC incorrect", channel.json_body["error"])
 
     def test_register_correct_nonce(self):
@@ -157,17 +154,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         )
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob",
-                "password": "abc123",
-                "admin": True,
-                "user_type": UserTypes.SUPPORT,
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob",
+            "password": "abc123",
+            "admin": True,
+            "user_type": UserTypes.SUPPORT,
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -183,22 +178,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob",
-                "password": "abc123",
-                "admin": True,
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob",
+            "password": "abc123",
+            "admin": True,
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
         # Now, try and reuse it
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -218,9 +211,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         # Nonce check
         #
 
-        # Must be present
-        body = json.dumps({})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        # Must be an empty body present
+        channel = self.make_request("POST", self.url, {})
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -230,29 +222,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         #
 
         # Must be present
-        body = json.dumps({"nonce": nonce()})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        channel = self.make_request("POST", self.url, {"nonce": nonce()})
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("username must be specified", channel.json_body["error"])
 
         # Must be a string
-        body = json.dumps({"nonce": nonce(), "username": 1234})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": 1234}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
-        body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": "abcd\u0000"}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
-        body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": "a" * 1000}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid username", channel.json_body["error"])
@@ -262,29 +253,29 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         #
 
         # Must be present
-        body = json.dumps({"nonce": nonce(), "username": "a"})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": "a"}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("password must be specified", channel.json_body["error"])
 
         # Must be a string
-        body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": "a", "password": 1234}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Must not have null bytes
-        body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Super long
-        body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid password", channel.json_body["error"])
@@ -294,15 +285,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         #
 
         # Invalid user_type
-        body = json.dumps(
-            {
-                "nonce": nonce(),
-                "username": "a",
-                "password": "1234",
-                "user_type": "invalid",
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce(),
+            "username": "a",
+            "password": "1234",
+            "user_type": "invalid",
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -320,10 +309,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob1",
+            "password": "abc123",
+            "mac": want_mac,
+        }
+
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob1:test", channel.json_body["user_id"])
@@ -340,16 +333,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob2",
-                "displayname": None,
-                "password": "abc123",
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob2",
+            "displayname": None,
+            "password": "abc123",
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob2:test", channel.json_body["user_id"])
@@ -366,22 +357,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob3",
-                "displayname": "",
-                "password": "abc123",
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob3",
+            "displayname": "",
+            "password": "abc123",
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob3:test", channel.json_body["user_id"])
 
         channel = self.make_request("GET", "/profile/@bob3:test/displayname")
-        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(404, channel.code, msg=channel.json_body)
 
         # set displayname
         channel = self.make_request("GET", self.url)
@@ -391,16 +380,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob4",
-                "displayname": "Bob's Name",
-                "password": "abc123",
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob4",
+            "displayname": "Bob's Name",
+            "password": "abc123",
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob4:test", channel.json_body["user_id"])
@@ -440,17 +427,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         )
         want_mac = want_mac.hexdigest()
 
-        body = json.dumps(
-            {
-                "nonce": nonce,
-                "username": "bob",
-                "password": "abc123",
-                "admin": True,
-                "user_type": UserTypes.SUPPORT,
-                "mac": want_mac,
-            }
-        )
-        channel = self.make_request("POST", self.url, body.encode("utf8"))
+        body = {
+            "nonce": nonce,
+            "username": "bob",
+            "password": "abc123",
+            "admin": True,
+            "user_type": UserTypes.SUPPORT,
+            "mac": want_mac,
+        }
+        channel = self.make_request("POST", self.url, body)
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -993,12 +978,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
         """
         If parameter `erase` is not boolean, return an error
         """
-        body = json.dumps({"erase": "False"})
 
         channel = self.make_request(
             "POST",
             self.url,
-            content=body.encode(encoding="utf_8"),
+            content={"erase": "False"},
             access_token=self.admin_user_tok,
         )
 
@@ -2201,7 +2185,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
         """
         channel = self.make_request("GET", self.url, b"{}")
 
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_requester_is_no_admin(self):
@@ -2216,7 +2200,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
             access_token=other_user_token,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_does_not_exist(self):
@@ -2359,7 +2343,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         """
         channel = self.make_request("GET", self.url, b"{}")
 
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_requester_is_no_admin(self):
@@ -2374,7 +2358,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
             access_token=other_user_token,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_does_not_exist(self):
@@ -3073,7 +3057,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         """Try to login as a user without authentication."""
         channel = self.make_request("POST", self.url, b"{}")
 
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_not_admin(self):
@@ -3082,7 +3066,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
             "POST", self.url, b"{}", access_token=self.other_user_tok
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
 
     def test_send_event(self):
         """Test that sending event as a user works."""
@@ -3127,7 +3111,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
 
         # The puppet token should no longer work
         channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
 
         # .. but the real user's tokens should still work
         channel = self.make_request(
@@ -3160,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "GET", "devices", b"{}", access_token=self.other_user_tok
         )
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
 
     def test_admin_logout_all(self):
         """Tests that the admin user calling `/logout/all` does expire the
@@ -3181,7 +3165,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
 
         # The puppet token should no longer work
         channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
 
         # .. but the real user's tokens should still work
         channel = self.make_request(
@@ -3242,6 +3226,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
         self.helper.join(room_id, user=self.other_user, tok=puppet_token)
 
 
+@parameterized_class(
+    ("url_prefix",),
+    [
+        ("/_synapse/admin/v1/whois/%s",),
+        ("/_matrix/client/r0/admin/whois/%s",),
+    ],
+)
 class WhoisRestTestCase(unittest.HomeserverTestCase):
 
     servlets = [
@@ -3254,21 +3245,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         self.admin_user_tok = self.login("admin", "pass")
 
         self.other_user = self.register_user("user", "pass")
-        self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user)
-        self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote(
-            self.other_user
-        )
+        self.url = self.url_prefix % self.other_user
 
     def test_no_auth(self):
         """
         Try to get information of an user without authentication.
         """
-        channel = self.make_request("GET", self.url1, b"{}")
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("GET", self.url2, b"{}")
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        channel = self.make_request("GET", self.url, b"{}")
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_requester_is_not_admin(self):
@@ -3280,38 +3264,21 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            self.url1,
-            access_token=other_user2_token,
-        )
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "GET",
-            self.url2,
+            self.url,
             access_token=other_user2_token,
         )
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_is_not_local(self):
         """
         Tests that a lookup for a user that is not a local returns a 400
         """
-        url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
-        url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
-
-        channel = self.make_request(
-            "GET",
-            url1,
-            access_token=self.admin_user_tok,
-        )
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only whois a local user", channel.json_body["error"])
+        url = self.url_prefix % "@unknown_person:unknown_domain"
 
         channel = self.make_request(
             "GET",
-            url2,
+            url,
             access_token=self.admin_user_tok,
         )
         self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -3323,16 +3290,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
         """
         channel = self.make_request(
             "GET",
-            self.url1,
-            access_token=self.admin_user_tok,
-        )
-        self.assertEqual(200, channel.code, msg=channel.json_body)
-        self.assertEqual(self.other_user, channel.json_body["user_id"])
-        self.assertIn("devices", channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            self.url2,
+            self.url,
             access_token=self.admin_user_tok,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -3347,16 +3305,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "GET",
-            self.url1,
-            access_token=other_user_token,
-        )
-        self.assertEqual(200, channel.code, msg=channel.json_body)
-        self.assertEqual(self.other_user, channel.json_body["user_id"])
-        self.assertIn("devices", channel.json_body)
-
-        channel = self.make_request(
-            "GET",
-            self.url2,
+            self.url,
             access_token=other_user_token,
         )
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -3388,7 +3337,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         Try to get information of an user without authentication.
         """
         channel = self.make_request("POST", self.url)
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
     def test_requester_is_not_admin(self):
@@ -3398,7 +3347,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         other_user_token = self.login("user", "pass")
 
         channel = self.make_request("POST", self.url, access_token=other_user_token)
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
     def test_user_is_not_local(self):
@@ -3447,84 +3396,41 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
             % urllib.parse.quote(self.other_user)
         )
 
-    def test_no_auth(self):
+    @parameterized.expand(["GET", "POST", "DELETE"])
+    def test_no_auth(self, method: str):
         """
         Try to get information of a user without authentication.
         """
-        channel = self.make_request("GET", self.url, b"{}")
-
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("POST", self.url, b"{}")
+        channel = self.make_request(method, self.url, b"{}")
 
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("DELETE", self.url, b"{}")
-
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-    def test_requester_is_no_admin(self):
+    @parameterized.expand(["GET", "POST", "DELETE"])
+    def test_requester_is_no_admin(self, method: str):
         """
         If the user is not a server admin, an error is returned.
         """
         other_user_token = self.login("user", "pass")
 
         channel = self.make_request(
-            "GET",
-            self.url,
-            access_token=other_user_token,
-        )
-
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "POST",
-            self.url,
-            access_token=other_user_token,
-        )
-
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             self.url,
             access_token=other_user_token,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_user_does_not_exist(self):
+    @parameterized.expand(["GET", "POST", "DELETE"])
+    def test_user_does_not_exist(self, method: str):
         """
         Tests that a lookup for a user that does not exist returns a 404
         """
         url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
 
         channel = self.make_request(
-            "GET",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "POST",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
@@ -3532,7 +3438,14 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-    def test_user_is_not_local(self):
+    @parameterized.expand(
+        [
+            ("GET", "Can only look up local users"),
+            ("POST", "Only local users can be ratelimited"),
+            ("DELETE", "Only local users can be ratelimited"),
+        ]
+    )
+    def test_user_is_not_local(self, method: str, error_msg: str):
         """
         Tests that a lookup for a user that is not a local returns a 400
         """
@@ -3541,35 +3454,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
         )
 
         channel = self.make_request(
-            "GET",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only look up local users", channel.json_body["error"])
-
-        channel = self.make_request(
-            "POST",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual(
-            "Only local users can be ratelimited", channel.json_body["error"]
-        )
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual(
-            "Only local users can be ratelimited", channel.json_body["error"]
-        )
+        self.assertEqual(error_msg, channel.json_body["error"])
 
     def test_invalid_parameter(self):
         """
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..3c7d49f0b4 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,15 +13,15 @@
 # limitations under the License.
 
 import itertools
-import json
-import urllib
-from typing import Optional
+import urllib.parse
+from typing import Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room
 
 from tests import unittest
+from tests.server import FakeChannel
 
 
 class RelationsTestCase(unittest.HomeserverTestCase):
@@ -34,16 +34,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def make_homeserver(self, reactor, clock):
+    def default_config(self) -> dict:
         # We need to enable msc1849 support for aggregations
-        config = self.default_config()
+        config = super().default_config()
         config["experimental_msc1849_support_enabled"] = True
 
         # We enable frozen dicts as relations/edits change event contents, so we
         # want to test that we don't modify the events in the caches.
         config["use_frozen_dicts"] = True
 
-        return self.setup_test_homeserver(config=config)
+        return config
 
     def prepare(self, reactor, clock, hs):
         self.user_id, self.user_token = self._create_user("alice")
@@ -146,8 +146,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
-        prev_token = None
-        found_event_ids = []
+        prev_token: Optional[str] = None
+        found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
             if prev_token:
@@ -203,8 +203,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             idx += 1
             idx %= len(access_tokens)
 
-        prev_token = None
-        found_groups = {}
+        prev_token: Optional[str] = None
+        found_groups: Dict[str, int] = {}
         for _ in range(20):
             from_token = ""
             if prev_token:
@@ -270,8 +270,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        prev_token = None
-        found_event_ids = []
+        prev_token: Optional[str] = None
+        found_event_ids: List[str] = []
         encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
             from_token = ""
@@ -677,24 +677,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
     def _send_relation(
         self,
-        relation_type,
-        event_type,
-        key=None,
+        relation_type: str,
+        event_type: str,
+        key: Optional[str] = None,
         content: Optional[dict] = None,
-        access_token=None,
-        parent_id=None,
-    ):
+        access_token: Optional[str] = None,
+        parent_id: Optional[str] = None,
+    ) -> FakeChannel:
         """Helper function to send a relation pointing at `self.parent_id`
 
         Args:
-            relation_type (str): One of `RelationTypes`
-            event_type (str): The type of the event to create
-            parent_id (str): The event_id this relation relates to. If None, then self.parent_id
-            key (str|None): The aggregation key used for m.annotation relation
-                type.
-            content(dict|None): The content of the created event.
-            access_token (str|None): The access token used to send the relation,
-                defaults to `self.user_token`
+            relation_type: One of `RelationTypes`
+            event_type: The type of the event to create
+            key: The aggregation key used for m.annotation relation type.
+            content: The content of the created event.
+            access_token: The access token used to send the relation, defaults
+                to `self.user_token`
+            parent_id: The event_id this relation relates to. If None, then self.parent_id
 
         Returns:
             FakeChannel
@@ -712,12 +711,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             "POST",
             "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
             % (self.room, original_id, relation_type, event_type, query),
-            json.dumps(content or {}).encode("utf-8"),
+            content or {},
             access_token=access_token,
         )
         return channel
 
-    def _create_user(self, localpart):
+    def _create_user(self, localpart: str) -> Tuple[str, str]:
         user_id = self.register_user(localpart, "abc123")
         access_token = self.login(localpart, "abc123")
 
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 38ac9be113..531f09c48b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -12,25 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import threading
-from typing import Dict
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from unittest.mock import Mock
 
 from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
-from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, room
-from synapse.types import Requester, StateMap
+from synapse.types import JsonDict, Requester, StateMap
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
 
+if TYPE_CHECKING:
+    from synapse.module_api import ModuleApi
+
 thread_local = threading.local()
 
 
 class LegacyThirdPartyRulesTestModule:
-    def __init__(self, config: Dict, module_api: ModuleApi):
+    def __init__(self, config: Dict, module_api: "ModuleApi"):
         # keep a record of the "current" rules module, so that the test can patch
         # it if desired.
         thread_local.rules_module = self
@@ -50,7 +53,7 @@ class LegacyThirdPartyRulesTestModule:
 
 
 class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
-    def __init__(self, config: Dict, module_api: ModuleApi):
+    def __init__(self, config: Dict, module_api: "ModuleApi"):
         super().__init__(config, module_api)
 
     def on_create_room(
@@ -60,7 +63,7 @@ class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
 
 
 class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
-    def __init__(self, config: Dict, module_api: ModuleApi):
+    def __init__(self, config: Dict, module_api: "ModuleApi"):
         super().__init__(config, module_api)
 
     async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
@@ -136,6 +139,47 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(channel.result["code"], b"403", channel.result)
 
+    def test_third_party_rules_workaround_synapse_errors_pass_through(self):
+        """
+        Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
+        is functional: that SynapseErrors are passed through from check_event_allowed
+        and bubble up to the web resource.
+
+        NEW MODULES SHOULD NOT MAKE USE OF THIS WORKAROUND!
+        This is a temporary workaround!
+        """
+
+        class NastyHackException(SynapseError):
+            def error_dict(self):
+                """
+                This overrides SynapseError's `error_dict` to nastily inject
+                JSON into the error response.
+                """
+                result = super().error_dict()
+                result["nasty"] = "very"
+                return result
+
+        # add a callback that will raise our hacky exception
+        async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
+            raise NastyHackException(429, "message")
+
+        self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+
+        # Make a request
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
+            {},
+            access_token=self.tok,
+        )
+        # Check the error code
+        self.assertEquals(channel.result["code"], b"429", channel.result)
+        # Check the JSON body has had the `nasty` key injected
+        self.assertEqual(
+            channel.json_body,
+            {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
+        )
+
     def test_cannot_modify_event(self):
         """cannot accidentally modify an event before it is persisted"""
 
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
new file mode 100644
index 0000000000..09504a485f
--- /dev/null
+++ b/tests/rest/media/v1/test_filepath.py
@@ -0,0 +1,238 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+from tests import unittest
+
+
+class MediaFilePathsTestCase(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+
+        self.filepaths = MediaFilePaths("/media_store")
+
+    def test_local_media_filepath(self):
+        """Test local media paths"""
+        self.assertEqual(
+            self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
+            "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+        self.assertEqual(
+            self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"),
+            "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+
+    def test_local_media_thumbnail(self):
+        """Test local media thumbnail paths"""
+        self.assertEqual(
+            self.filepaths.local_media_thumbnail_rel(
+                "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+            ),
+            "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+        )
+        self.assertEqual(
+            self.filepaths.local_media_thumbnail(
+                "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+            ),
+            "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+        )
+
+    def test_local_media_thumbnail_dir(self):
+        """Test local media thumbnail directory paths"""
+        self.assertEqual(
+            self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
+            "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+
+    def test_remote_media_filepath(self):
+        """Test remote media paths"""
+        self.assertEqual(
+            self.filepaths.remote_media_filepath_rel(
+                "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+            ),
+            "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+        self.assertEqual(
+            self.filepaths.remote_media_filepath(
+                "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+            ),
+            "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+
+    def test_remote_media_thumbnail(self):
+        """Test remote media thumbnail paths"""
+        self.assertEqual(
+            self.filepaths.remote_media_thumbnail_rel(
+                "example.com",
+                "GerZNDnDZVjsOtardLuwfIBg",
+                800,
+                600,
+                "image/jpeg",
+                "scale",
+            ),
+            "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+        )
+        self.assertEqual(
+            self.filepaths.remote_media_thumbnail(
+                "example.com",
+                "GerZNDnDZVjsOtardLuwfIBg",
+                800,
+                600,
+                "image/jpeg",
+                "scale",
+            ),
+            "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+        )
+
+    def test_remote_media_thumbnail_legacy(self):
+        """Test old-style remote media thumbnail paths"""
+        self.assertEqual(
+            self.filepaths.remote_media_thumbnail_rel_legacy(
+                "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg"
+            ),
+            "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
+        )
+
+    def test_remote_media_thumbnail_dir(self):
+        """Test remote media thumbnail directory paths"""
+        self.assertEqual(
+            self.filepaths.remote_media_thumbnail_dir(
+                "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+            ),
+            "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+
+    def test_url_cache_filepath(self):
+        """Test URL cache paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
+            "url_cache/2020-01-02/GerZNDnDZVjsOtar",
+        )
+        self.assertEqual(
+            self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"),
+            "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
+        )
+
+    def test_url_cache_filepath_legacy(self):
+        """Test old-style URL cache paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
+            "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+        self.assertEqual(
+            self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"),
+            "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+
+    def test_url_cache_filepath_dirs_to_delete(self):
+        """Test URL cache cleanup paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_filepath_dirs_to_delete(
+                "2020-01-02_GerZNDnDZVjsOtar"
+            ),
+            ["/media_store/url_cache/2020-01-02"],
+        )
+
+    def test_url_cache_filepath_dirs_to_delete_legacy(self):
+        """Test old-style URL cache cleanup paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_filepath_dirs_to_delete(
+                "GerZNDnDZVjsOtardLuwfIBg"
+            ),
+            [
+                "/media_store/url_cache/Ge/rZ",
+                "/media_store/url_cache/Ge",
+            ],
+        )
+
+    def test_url_cache_thumbnail(self):
+        """Test URL cache thumbnail paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_rel(
+                "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
+            ),
+            "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
+        )
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail(
+                "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
+            ),
+            "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
+        )
+
+    def test_url_cache_thumbnail_legacy(self):
+        """Test old-style URL cache thumbnail paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_rel(
+                "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+            ),
+            "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+        )
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail(
+                "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+            ),
+            "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+        )
+
+    def test_url_cache_thumbnail_directory(self):
+        """Test URL cache thumbnail directory paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_directory_rel(
+                "2020-01-02_GerZNDnDZVjsOtar"
+            ),
+            "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+        )
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"),
+            "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+        )
+
+    def test_url_cache_thumbnail_directory_legacy(self):
+        """Test old-style URL cache thumbnail directory paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_directory_rel(
+                "GerZNDnDZVjsOtardLuwfIBg"
+            ),
+            "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"),
+            "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+        )
+
+    def test_url_cache_thumbnail_dirs_to_delete(self):
+        """Test URL cache thumbnail cleanup paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_dirs_to_delete(
+                "2020-01-02_GerZNDnDZVjsOtar"
+            ),
+            [
+                "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+                "/media_store/url_cache_thumbnails/2020-01-02",
+            ],
+        )
+
+    def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+        """Test old-style URL cache thumbnail cleanup paths"""
+        self.assertEqual(
+            self.filepaths.url_cache_thumbnail_dirs_to_delete(
+                "GerZNDnDZVjsOtardLuwfIBg"
+            ),
+            [
+                "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+                "/media_store/url_cache_thumbnails/Ge/rZ",
+                "/media_store/url_cache_thumbnails/Ge",
+            ],
+        )
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
new file mode 100644
index 0000000000..048d0ca44a
--- /dev/null
+++ b/tests/rest/media/v1/test_oembed.py
@@ -0,0 +1,51 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import json
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class OEmbedTests(HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+        self.oembed = OEmbedProvider(homeserver)
+
+    def parse_response(self, response: JsonDict):
+        return self.oembed.parse_oembed_response(
+            "https://test", json.dumps(response).encode("utf-8")
+        )
+
+    def test_version(self):
+        """Accept versions that are similar to 1.0 as a string or int (or missing)."""
+        for version in ("1.0", 1.0, 1):
+            result = self.parse_response({"version": version, "type": "link"})
+            # An empty Open Graph response is an error, ensure the URL is included.
+            self.assertIn("og:url", result.open_graph_result)
+
+        # A missing version should be treated as 1.0.
+        result = self.parse_response({"type": "link"})
+        self.assertIn("og:url", result.open_graph_result)
+
+        # Invalid versions should be rejected.
+        for version in ("2.0", "1", 1.1, 0, None, {}, []):
+            result = self.parse_response({"version": version, "type": "link"})
+            # An empty Open Graph response is an error, ensure the URL is included.
+            self.assertEqual({}, result.open_graph_result)
diff --git a/tests/server.py b/tests/server.py
index 64645651ce..103351b487 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,3 +1,17 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import logging
 from collections import deque
@@ -27,9 +41,10 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
-from twisted.web.server import Site
+from twisted.web.server import Request, Site
 
 from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests.utils import setup_test_homeserver as _sth
@@ -198,14 +213,14 @@ class FakeSite:
 def make_request(
     reactor,
     site: Union[Site, FakeSite],
-    method,
-    path,
-    content=b"",
-    access_token=None,
-    request=SynapseRequest,
-    shorthand=True,
-    federation_auth_origin=None,
-    content_is_form=False,
+    method: Union[bytes, str],
+    path: Union[bytes, str],
+    content: Union[bytes, str, JsonDict] = b"",
+    access_token: Optional[str] = None,
+    request: Request = SynapseRequest,
+    shorthand: bool = True,
+    federation_auth_origin: Optional[bytes] = None,
+    content_is_form: bool = False,
     await_result: bool = True,
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
@@ -218,26 +233,23 @@ def make_request(
     Returns the fake Channel object which records the response to the request.
 
     Args:
+        reactor:
         site: The twisted Site to use to render the request
-
-        method (bytes/unicode): The HTTP request method ("verb").
-        path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
-        escaped UTF-8 & spaces and such).
-        content (bytes or dict): The body of the request. JSON-encoded, if
-        a dict.
+        method: The HTTP request method ("verb").
+        path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
+        content: The body of the request. JSON-encoded, if a str of bytes.
+        access_token: The access token to add as authorization for the request.
+        request: The request class to create.
         shorthand: Whether to try and be helpful and prefix the given URL
-        with the usual REST API path, if it doesn't contain it.
-        federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+            with the usual REST API path, if it doesn't contain it.
+        federation_auth_origin: if set to not-None, we will add a fake
             Authorization header pretenting to be the given server name.
         content_is_form: Whether the content is URL encoded form data. Adds the
             'Content-Type': 'application/x-www-form-urlencoded' header.
-
-        custom_headers: (name, value) pairs to add as request headers
-
         await_result: whether to wait for the request to complete rendering. If true,
              will pump the reactor until the the renderer tells the channel the request
              is finished.
-
+        custom_headers: (name, value) pairs to add as request headers
         client_ip: The IP to use as the requesting IP. Useful for testing
             ratelimiting.
 
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index be3ed64f5e..37cf7bb232 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any, Dict, Set, Tuple
 from unittest import mock
 from unittest.mock import Mock, patch
 
@@ -42,18 +42,7 @@ class GetUserDirectoryTables:
     def __init__(self, store: DataStore):
         self.store = store
 
-    def _compress_shared(
-        self, shared: List[Dict[str, str]]
-    ) -> Set[Tuple[str, str, str]]:
-        """
-        Compress a list of users who share rooms dicts to a list of tuples.
-        """
-        r = set()
-        for i in shared:
-            r.add((i["user_id"], i["other_user_id"], i["room_id"]))
-        return r
-
-    async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
+    async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]:
         """Fetch the entire `users_in_public_rooms` table.
 
         Returns a list of tuples (user_id, room_id) where room_id is public and
@@ -63,24 +52,27 @@ class GetUserDirectoryTables:
             "users_in_public_rooms", None, ("user_id", "room_id")
         )
 
-        retval = []
+        retval = set()
         for i in r:
-            retval.append((i["user_id"], i["room_id"]))
+            retval.add((i["user_id"], i["room_id"]))
         return retval
 
-    async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
+    async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
         """Fetch the entire `users_who_share_private_rooms` table.
 
-        Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
-        The dicts can be flattened to Tuples with the `_compress_shared` method.
-        (This seems a little awkward---maybe we could clean this up.)
+        Returns a set of tuples (user_id, other_user_id, room_id) corresponding
+        to the rows of `users_who_share_private_rooms`.
         """
 
-        return await self.store.db_pool.simple_select_list(
+        rows = await self.store.db_pool.simple_select_list(
             "users_who_share_private_rooms",
             None,
             ["user_id", "other_user_id", "room_id"],
         )
+        rv = set()
+        for row in rows:
+            rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
+        return rv
 
     async def get_users_in_user_directory(self) -> Set[str]:
         """Fetch the set of users in the `user_directory` table.
@@ -113,6 +105,16 @@ class GetUserDirectoryTables:
             for row in rows
         }
 
+    async def get_tables(
+        self,
+    ) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
+        """Multiple tests want to inspect these tables, so expose them together."""
+        return (
+            await self.get_users_in_user_directory(),
+            await self.get_users_in_public_rooms(),
+            await self.get_users_who_share_private_rooms(),
+        )
+
 
 class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
     """Ensure that rebuilding the directory writes the correct data to the DB.
@@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         )
 
         # Nothing updated yet
-        self.assertEqual(shares_private, [])
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, set())
+        self.assertEqual(public_users, set())
 
         # Ugh, have to reset this flag
         self.store.db_pool.updates._all_done = False
@@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._purge_and_rebuild_user_dir()
 
-        shares_private = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
-        )
-        public_users = self.get_success(
-            self.user_dir_helper.get_users_in_public_rooms()
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
         )
 
         # User 1 and User 2 are in the same public room
-        self.assertEqual(set(public_users), {(u1, room), (u2, room)})
-
+        self.assertEqual(in_public, {(u1, room), (u2, room)})
         # User 1 and User 3 share private rooms
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u3, private_room), (u3, u1, private_room)},
-        )
-
+        self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)})
         # All three should have entries in the directory
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
         self.assertEqual(users, {u1, u2, u3})
 
     # The next four tests (test_population_excludes_*) all set up
@@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         self, normal_user: str, public_room: str, private_room: str
     ) -> None:
         # After rebuilding the directory, we should only see the normal user.
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
-        self.assertEqual(users, {normal_user})
-        in_public_rooms = self.get_success(
-            self.user_dir_helper.get_users_in_public_rooms()
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
         )
-        self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
-        in_private_rooms = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
-        )
-        self.assertEqual(in_private_rooms, [])
+        self.assertEqual(users, {normal_user})
+        self.assertEqual(in_public, {(normal_user, public_room)})
+        self.assertEqual(in_private, set())
 
     def test_population_excludes_support_user(self) -> None:
         # Create a normal and support user.
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index cf407c51cf..e2c506e5a4 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -24,6 +24,47 @@ from synapse.types import JsonDict, get_domain_from_id
 
 
 class EventAuthTestCase(unittest.TestCase):
+    def test_rejected_auth_events(self):
+        """
+        Events that refer to rejected events in their auth events are rejected
+        """
+        creator = "@creator:example.com"
+        auth_events = [
+            _create_event(creator),
+            _join_event(creator),
+        ]
+
+        # creator should be able to send state
+        event_auth.check_auth_rules_for_event(
+            RoomVersions.V9,
+            _random_state_event(creator),
+            auth_events,
+        )
+
+        # ... but a rejected join_rules event should cause it to be rejected
+        rejected_join_rules = _join_rules_event(creator, "public")
+        rejected_join_rules.rejected_reason = "stinky"
+        auth_events.append(rejected_join_rules)
+
+        self.assertRaises(
+            AuthError,
+            event_auth.check_auth_rules_for_event,
+            RoomVersions.V9,
+            _random_state_event(creator),
+            auth_events,
+        )
+
+        # ... even if there is *also* a good join rules
+        auth_events.append(_join_rules_event(creator, "public"))
+
+        self.assertRaises(
+            AuthError,
+            event_auth.check_auth_rules_for_event,
+            RoomVersions.V9,
+            _random_state_event(creator),
+            auth_events,
+        )
+
     def test_random_users_cannot_send_state_before_first_pl(self):
         """
         Check that, before the first PL lands, the creator is the only user
@@ -31,11 +72,11 @@ class EventAuthTestCase(unittest.TestCase):
         """
         creator = "@creator:example.com"
         joiner = "@joiner:example.com"
-        auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-            ("m.room.member", joiner): _join_event(joiner),
-        }
+        auth_events = [
+            _create_event(creator),
+            _join_event(creator),
+            _join_event(joiner),
+        ]
 
         # creator should be able to send state
         event_auth.check_auth_rules_for_event(
@@ -62,15 +103,15 @@ class EventAuthTestCase(unittest.TestCase):
         pleb = "@joiner:example.com"
         king = "@joiner2:example.com"
 
-        auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-            ("m.room.power_levels", ""): _power_levels_event(
+        auth_events = [
+            _create_event(creator),
+            _join_event(creator),
+            _power_levels_event(
                 creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
             ),
-            ("m.room.member", pleb): _join_event(pleb),
-            ("m.room.member", king): _join_event(king),
-        }
+            _join_event(pleb),
+            _join_event(king),
+        ]
 
         # pleb should not be able to send state
         self.assertRaises(
@@ -92,10 +133,10 @@ class EventAuthTestCase(unittest.TestCase):
         """Alias events have special behavior up through room version 6."""
         creator = "@creator:example.com"
         other = "@other:example.com"
-        auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-        }
+        auth_events = [
+            _create_event(creator),
+            _join_event(creator),
+        ]
 
         # creator should be able to send aliases
         event_auth.check_auth_rules_for_event(
@@ -131,10 +172,10 @@ class EventAuthTestCase(unittest.TestCase):
         """After MSC2432, alias events have no special behavior."""
         creator = "@creator:example.com"
         other = "@other:example.com"
-        auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-        }
+        auth_events = [
+            _create_event(creator),
+            _join_event(creator),
+        ]
 
         # creator should be able to send aliases
         event_auth.check_auth_rules_for_event(
@@ -170,14 +211,14 @@ class EventAuthTestCase(unittest.TestCase):
         creator = "@creator:example.com"
         pleb = "@joiner:example.com"
 
-        auth_events = {
-            ("m.room.create", ""): _create_event(creator),
-            ("m.room.member", creator): _join_event(creator),
-            ("m.room.power_levels", ""): _power_levels_event(
+        auth_events = [
+            _create_event(creator),
+            _join_event(creator),
+            _power_levels_event(
                 creator, {"state_default": "30", "users": {pleb: "30"}}
             ),
-            ("m.room.member", pleb): _join_event(pleb),
-        }
+            _join_event(pleb),
+        ]
 
         # pleb should be able to modify the notifications power level.
         event_auth.check_auth_rules_for_event(
@@ -211,7 +252,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V6,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
         # A user cannot be force-joined to a room.
@@ -219,7 +260,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _member_event(pleb, "join", sender=creator),
-                auth_events,
+                auth_events.values(),
             )
 
         # Banned should be rejected.
@@ -228,7 +269,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _join_event(pleb),
-                auth_events,
+                auth_events.values(),
             )
 
         # A user who left can re-join.
@@ -236,7 +277,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V6,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
         # A user can send a join if they're in the room.
@@ -244,7 +285,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V6,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
         # A user can accept an invite.
@@ -254,7 +295,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V6,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
     def test_join_rules_invite(self):
@@ -275,7 +316,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _join_event(pleb),
-                auth_events,
+                auth_events.values(),
             )
 
         # A user cannot be force-joined to a room.
@@ -283,7 +324,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _member_event(pleb, "join", sender=creator),
-                auth_events,
+                auth_events.values(),
             )
 
         # Banned should be rejected.
@@ -292,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _join_event(pleb),
-                auth_events,
+                auth_events.values(),
             )
 
         # A user who left cannot re-join.
@@ -301,7 +342,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _join_event(pleb),
-                auth_events,
+                auth_events.values(),
             )
 
         # A user can send a join if they're in the room.
@@ -309,7 +350,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V6,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
         # A user can accept an invite.
@@ -319,7 +360,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V6,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
     def test_join_rules_msc3083_restricted(self):
@@ -347,7 +388,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V6,
                 _join_event(pleb),
-                auth_events,
+                auth_events.values(),
             )
 
         # A properly formatted join event should work.
@@ -360,7 +401,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V8,
             authorised_join_event,
-            auth_events,
+            auth_events.values(),
         )
 
         # A join issued by a specific user works (i.e. the power level checks
@@ -380,7 +421,7 @@ class EventAuthTestCase(unittest.TestCase):
                     EventContentFields.AUTHORISING_USER: "@inviter:foo.test"
                 },
             ),
-            pl_auth_events,
+            pl_auth_events.values(),
         )
 
         # A join which is missing an authorised server is rejected.
@@ -388,7 +429,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V8,
                 _join_event(pleb),
-                auth_events,
+                auth_events.values(),
             )
 
         # An join authorised by a user who is not in the room is rejected.
@@ -405,7 +446,7 @@ class EventAuthTestCase(unittest.TestCase):
                         EventContentFields.AUTHORISING_USER: "@other:example.com"
                     },
                 ),
-                auth_events,
+                auth_events.values(),
             )
 
         # A user cannot be force-joined to a room. (This uses an event which
@@ -421,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
                         EventContentFields.AUTHORISING_USER: "@inviter:foo.test"
                     },
                 ),
-                auth_events,
+                auth_events.values(),
             )
 
         # Banned should be rejected.
@@ -430,7 +471,7 @@ class EventAuthTestCase(unittest.TestCase):
             event_auth.check_auth_rules_for_event(
                 RoomVersions.V8,
                 authorised_join_event,
-                auth_events,
+                auth_events.values(),
             )
 
         # A user who left can re-join.
@@ -438,7 +479,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V8,
             authorised_join_event,
-            auth_events,
+            auth_events.values(),
         )
 
         # A user can send a join if they're in the room. (This doesn't need to
@@ -447,7 +488,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V8,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
         # A user can accept an invite. (This doesn't need to be authorised since
@@ -458,7 +499,7 @@ class EventAuthTestCase(unittest.TestCase):
         event_auth.check_auth_rules_for_event(
             RoomVersions.V8,
             _join_event(pleb),
-            auth_events,
+            auth_events.values(),
         )
 
 
@@ -473,6 +514,7 @@ def _create_event(user_id: str) -> EventBase:
             "room_id": TEST_ROOM_ID,
             "event_id": _get_event_id(),
             "type": "m.room.create",
+            "state_key": "",
             "sender": user_id,
             "content": {"creator": user_id},
         }
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 09e017b4d9..9a576f9a4e 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,7 +15,7 @@
 from synapse.rest.media.v1.preview_url_resource import (
     _calc_og,
     decode_body,
-    get_html_media_encoding,
+    get_html_media_encodings,
     summarize_paragraphs,
 )
 
@@ -159,7 +159,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -175,7 +175,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -194,7 +194,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(
@@ -216,7 +216,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -230,7 +230,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -245,7 +245,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -260,7 +260,7 @@ class CalcOgTestCase(unittest.TestCase):
         </html>
         """
 
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
 
         self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -268,13 +268,13 @@ class CalcOgTestCase(unittest.TestCase):
     def test_empty(self):
         """Test a body with no data in it."""
         html = b""
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         self.assertIsNone(tree)
 
     def test_no_tree(self):
         """A valid body with no tree in it."""
         html = b"\x00"
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         self.assertIsNone(tree)
 
     def test_invalid_encoding(self):
@@ -287,7 +287,7 @@ class CalcOgTestCase(unittest.TestCase):
         </body>
         </html>
         """
-        tree = decode_body(html, "invalid-encoding")
+        tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
         og = _calc_og(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
 
@@ -302,15 +302,29 @@ class CalcOgTestCase(unittest.TestCase):
         </body>
         </html>
         """
-        tree = decode_body(html)
+        tree = decode_body(html, "http://example.com/test.html")
         og = _calc_og(tree, "http://example.com/test.html")
         self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
 
+    def test_windows_1252(self):
+        """A body which uses cp1252, but doesn't declare that."""
+        html = b"""
+        <html>
+        <head><title>\xf3</title></head>
+        <body>
+        Some text.
+        </body>
+        </html>
+        """
+        tree = decode_body(html, "http://example.com/test.html")
+        og = _calc_og(tree, "http://example.com/test.html")
+        self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
+
 
 class MediaEncodingTestCase(unittest.TestCase):
     def test_meta_charset(self):
         """A character encoding is found via the meta tag."""
-        encoding = get_html_media_encoding(
+        encodings = get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="ascii">
@@ -319,10 +333,10 @@ class MediaEncodingTestCase(unittest.TestCase):
         """,
             "text/html",
         )
-        self.assertEqual(encoding, "ascii")
+        self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
         # A less well-formed version.
-        encoding = get_html_media_encoding(
+        encodings = get_html_media_encodings(
             b"""
         <html>
         <head>< meta charset = ascii>
@@ -331,11 +345,11 @@ class MediaEncodingTestCase(unittest.TestCase):
         """,
             "text/html",
         )
-        self.assertEqual(encoding, "ascii")
+        self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
     def test_meta_charset_underscores(self):
         """A character encoding contains underscore."""
-        encoding = get_html_media_encoding(
+        encodings = get_html_media_encodings(
             b"""
         <html>
         <head><meta charset="Shift_JIS">
@@ -344,11 +358,11 @@ class MediaEncodingTestCase(unittest.TestCase):
         """,
             "text/html",
         )
-        self.assertEqual(encoding, "Shift_JIS")
+        self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
 
     def test_xml_encoding(self):
         """A character encoding is found via the meta tag."""
-        encoding = get_html_media_encoding(
+        encodings = get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="ascii"?>
         <html>
@@ -356,11 +370,11 @@ class MediaEncodingTestCase(unittest.TestCase):
         """,
             "text/html",
         )
-        self.assertEqual(encoding, "ascii")
+        self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
     def test_meta_xml_encoding(self):
         """Meta tags take precedence over XML encoding."""
-        encoding = get_html_media_encoding(
+        encodings = get_html_media_encodings(
             b"""
         <?xml version="1.0" encoding="ascii"?>
         <html>
@@ -370,7 +384,7 @@ class MediaEncodingTestCase(unittest.TestCase):
         """,
             "text/html",
         )
-        self.assertEqual(encoding, "UTF-16")
+        self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
 
     def test_content_type(self):
         """A character encoding is found via the Content-Type header."""
@@ -384,10 +398,37 @@ class MediaEncodingTestCase(unittest.TestCase):
             'text/html; charset=ascii";',
         )
         for header in headers:
-            encoding = get_html_media_encoding(b"", header)
-            self.assertEqual(encoding, "ascii")
+            encodings = get_html_media_encodings(b"", header)
+            self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
 
     def test_fallback(self):
         """A character encoding cannot be found in the body or header."""
-        encoding = get_html_media_encoding(b"", "text/html")
-        self.assertEqual(encoding, "utf-8")
+        encodings = get_html_media_encodings(b"", "text/html")
+        self.assertEqual(list(encodings), ["utf-8", "cp1252"])
+
+    def test_duplicates(self):
+        """Ensure each encoding is only attempted once."""
+        encodings = get_html_media_encodings(
+            b"""
+        <?xml version="1.0" encoding="utf8"?>
+        <html>
+        <head><meta charset="UTF-8">
+        </head>
+        </html>
+        """,
+            'text/html; charset="UTF_8"',
+        )
+        self.assertEqual(list(encodings), ["utf-8", "cp1252"])
+
+    def test_unknown_invalid(self):
+        """A character encoding should be ignored if it is unknown or invalid."""
+        encodings = get_html_media_encodings(
+            b"""
+        <html>
+        <head><meta charset="invalid">
+        </head>
+        </html>
+        """,
+            'text/html; charset="invalid"',
+        )
+        self.assertEqual(list(encodings), ["utf-8", "cp1252"])
diff --git a/tests/unittest.py b/tests/unittest.py
index 81c1a9e9d2..a9b60b7eeb 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -46,7 +46,7 @@ from synapse.logging.context import (
     set_current_context,
 )
 from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -401,7 +401,7 @@ class HomeserverTestCase(TestCase):
         self,
         method: Union[bytes, str],
         path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
+        content: Union[bytes, str, JsonDict] = b"",
         access_token: Optional[str] = None,
         request: Type[T] = SynapseRequest,
         shorthand: bool = True,