author     Erik Johnston <erik@matrix.org>    2024-05-29 13:39:34 +0100
committer  Erik Johnston <erik@matrix.org>    2024-05-29 13:40:10 +0100
commit     484845524fb3e946588cbceaa3e4f3e5785502ab (patch)
tree       d3a82b5f3ea902158ce2c8353c1b581235c8e775 /synapse
parent     Merge branch 'rei/task_scheduler_better_logging' into matrix-org-hotfixes (diff)
parent     Move towards using `MultiWriterIdGenerator` everywhere (#17226) (diff)
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/app/_base.py                                 |   6
-rw-r--r--  synapse/config/_base.pyi                             |   2
-rw-r--r--  synapse/config/auto_accept_invites.py                |  43
-rw-r--r--  synapse/config/experimental.py                       |   7
-rw-r--r--  synapse/config/homeserver.py                         |   2
-rw-r--r--  synapse/events/auto_accept_invites.py                | 196
-rw-r--r--  synapse/handlers/devicemessage.py                    |   7
-rw-r--r--  synapse/handlers/e2e_keys.py                         |  78
-rw-r--r--  synapse/handlers/register.py                         |   2
-rw-r--r--  synapse/handlers/relations.py                        |   2
-rw-r--r--  synapse/handlers/sso.py                              |   2
-rw-r--r--  synapse/handlers/sync.py                             | 247
-rw-r--r--  synapse/media/media_repository.py                    |  11
-rw-r--r--  synapse/media/media_storage.py                       | 102
-rw-r--r--  synapse/media/thumbnailer.py                         | 486
-rw-r--r--  synapse/media/url_previewer.py                       |   4
-rw-r--r--  synapse/rest/client/media.py                         | 205
-rw-r--r--  synapse/rest/client/sync.py                          | 171
-rw-r--r--  synapse/rest/media/thumbnail_resource.py             | 476
-rw-r--r--  synapse/storage/database.py                          |  21
-rw-r--r--  synapse/storage/databases/main/account_data.py       |  47
-rw-r--r--  synapse/storage/databases/main/cache.py              |  18
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py        |  46
-rw-r--r--  synapse/storage/databases/main/devices.py            |   3
-rw-r--r--  synapse/storage/databases/main/events.py             |   7
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py  |  24
-rw-r--r--  synapse/storage/databases/main/events_worker.py      | 101
-rw-r--r--  synapse/storage/databases/main/presence.py           |  27
-rw-r--r--  synapse/storage/databases/main/receipts.py           |  43
-rw-r--r--  synapse/storage/databases/main/relations.py          |   2
-rw-r--r--  synapse/storage/databases/main/room.py               |  34
-rw-r--r--  synapse/storage/engines/postgres.py                  |   8
-rw-r--r--  synapse/storage/util/id_generators.py                |  49
33 files changed, 1659 insertions, 820 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3182608f73..67e0df1459 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -68,6 +68,7 @@ from synapse.config._base import format_config_error
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
 from synapse.crypto import context_factory
+from synapse.events.auto_accept_invites import InviteAutoAccepter
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.http.site import SynapseSite
@@ -582,6 +583,11 @@ async def start(hs: "HomeServer") -> None:
         m = module(config, module_api)
         logger.info("Loaded module %s", m)
 
+    if hs.config.auto_accept_invites.enabled:
+        # Start the local auto_accept_invites module.
+        m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+        logger.info("Loaded local module %s", m)
+
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index fc51aed234..d9cb0da38b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -23,6 +23,7 @@ from synapse.config import (  # noqa: F401
     api,
     appservice,
     auth,
+    auto_accept_invites,
     background_updates,
     cache,
     captcha,
@@ -120,6 +121,7 @@ class RootConfig:
     federation: federation.FederationConfig
     retention: retention.RetentionConfig
     background_updates: background_updates.BackgroundUpdateConfig
+    auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
 
     config_classes: List[Type["Config"]] = ...
     config_files: List[str]
diff --git a/synapse/config/auto_accept_invites.py b/synapse/config/auto_accept_invites.py
new file mode 100644
index 0000000000..d90e13a510
--- /dev/null
+++ b/synapse/config/auto_accept_invites.py
@@ -0,0 +1,43 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config
+
+
+class AutoAcceptInvitesConfig(Config):
+    section = "auto_accept_invites"
+
+    def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+        auto_accept_invites_config = config.get("auto_accept_invites") or {}
+
+        self.enabled = auto_accept_invites_config.get("enabled", False)
+
+        self.accept_invites_only_for_direct_messages = auto_accept_invites_config.get(
+            "only_for_direct_messages", False
+        )
+
+        self.accept_invites_only_from_local_users = auto_accept_invites_config.get(
+            "only_from_local_users", False
+        )
+
+        self.worker_to_run_on = auto_accept_invites_config.get("worker_to_run_on")
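For reference, a minimal standalone sketch of the same parsing logic on a plain dict (no Synapse imports; the helper name read_auto_accept_invites_section is ours, not Synapse's):

from typing import Any, Dict

def read_auto_accept_invites_section(config: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors AutoAcceptInvitesConfig.read_config above: every key is
    # optional, the feature defaults to disabled, and worker_to_run_on
    # defaults to None (i.e. run on the main process).
    section = config.get("auto_accept_invites") or {}
    return {
        "enabled": section.get("enabled", False),
        "only_for_direct_messages": section.get("only_for_direct_messages", False),
        "only_from_local_users": section.get("only_from_local_users", False),
        "worker_to_run_on": section.get("worker_to_run_on"),
    }

# A missing or empty section yields all defaults:
assert read_auto_accept_invites_section({})["enabled"] is False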
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 749452ce93..75fe6d7b24 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
         # MSC3391: Removing account data.
         self.msc3391_enabled = experimental.get("msc3391_enabled", False)
 
+        # MSC3575 (Sliding Sync API endpoints)
+        self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
+
         # MSC3773: Thread notifications
         self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
 
@@ -436,3 +439,7 @@ class ExperimentalConfig(Config):
         self.msc4115_membership_on_events = experimental.get(
             "msc4115_membership_on_events", False
         )
+
+        self.msc3916_authenticated_media_enabled = experimental.get(
+            "msc3916_authenticated_media_enabled", False
+        )
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72e93ed04f..e36c0bd6ae 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .account_validity import AccountValidityConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
 from .auth import AuthConfig
+from .auto_accept_invites import AutoAcceptInvitesConfig
 from .background_updates import BackgroundUpdateConfig
 from .cache import CacheConfig
 from .captcha import CaptchaConfig
@@ -105,4 +106,5 @@ class HomeServerConfig(RootConfig):
         RedisConfig,
         ExperimentalConfig,
         BackgroundUpdateConfig,
+        AutoAcceptInvitesConfig,
     ]
diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py
new file mode 100644
index 0000000000..d88ec51d9d
--- /dev/null
+++ b/synapse/events/auto_accept_invites.py
@@ -0,0 +1,196 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from http import HTTPStatus
+from typing import Any, Dict, Tuple
+
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.module_api import EventBase, ModuleApi, run_as_background_process
+
+logger = logging.getLogger(__name__)
+
+
+class InviteAutoAccepter:
+    def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
+        # Keep a reference to the Module API.
+        self._api = api
+        self._config = config
+
+        if not self._config.enabled:
+            return
+
+        should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
+
+        if not should_run_on_this_worker:
+            logger.info(
+                "Not accepting invites on this worker (configured: %r, here: %r)",
+                config.worker_to_run_on,
+                self._api.worker_name,
+            )
+            return
+
+        logger.info(
+            "Accepting invites on this worker (here: %r)", self._api.worker_name
+        )
+
+        # Register the callback.
+        self._api.register_third_party_rules_callbacks(
+            on_new_event=self.on_new_event,
+        )
+
+    async def on_new_event(self, event: EventBase, *args: Any) -> None:
+        """Listens for new events, and if the event is an invite for a local user then
+        automatically accepts it.
+
+        Args:
+            event: The incoming event.
+        """
+        # Check if the event is an invite for a local user.
+        is_invite_for_local_user = (
+            event.type == EventTypes.Member
+            and event.is_state()
+            and event.membership == Membership.INVITE
+            and self._api.is_mine(event.state_key)
+        )
+
+        # Only accept invites for direct messages if the configuration mandates it.
+        is_direct_message = event.content.get("is_direct", False)
+        is_allowed_by_direct_message_rules = (
+            not self._config.accept_invites_only_for_direct_messages
+            or is_direct_message is True
+        )
+
+        # Only accept invites from remote users if the configuration mandates it.
+        is_from_local_user = self._api.is_mine(event.sender)
+        is_allowed_by_local_user_rules = (
+            not self._config.accept_invites_only_from_local_users
+            or is_from_local_user is True
+        )
+
+        if (
+            is_invite_for_local_user
+            and is_allowed_by_direct_message_rules
+            and is_allowed_by_local_user_rules
+        ):
+            # Make the user join the room. We run this as a background process to circumvent a race condition
+            # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
+            run_as_background_process(
+                "retry_make_join",
+                self._retry_make_join,
+                event.state_key,
+                event.state_key,
+                event.room_id,
+                "join",
+                bg_start_span=False,
+            )
+
+            if is_direct_message:
+                # Mark this room as a direct message!
+                await self._mark_room_as_direct_message(
+                    event.state_key, event.sender, event.room_id
+                )
+
+    async def _mark_room_as_direct_message(
+        self, user_id: str, dm_user_id: str, room_id: str
+    ) -> None:
+        """
+        Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
+        from the perspective of the user `user_id`.
+
+        Args:
+            user_id: the user for whom the membership is changing
+            dm_user_id: the user performing the membership change
+            room_id: room id of the room the user is invited to
+        """
+
+        # This is a dict of User IDs to tuples of Room IDs
+        # (get_global will return a frozendict of tuples as it freezes the data,
+        # but we should accept either frozen or unfrozen variants.)
+        # Be careful: we convert the outer frozendict into a dict here,
+        # but the contents of the dict are still frozen (tuples in lieu of lists,
+        # etc.)
+        dm_map: Dict[str, Tuple[str, ...]] = dict(
+            await self._api.account_data_manager.get_global(
+                user_id, AccountDataTypes.DIRECT
+            )
+            or {}
+        )
+
+        if dm_user_id not in dm_map:
+            dm_map[dm_user_id] = (room_id,)
+        else:
+            dm_rooms_for_user = dm_map[dm_user_id]
+            assert isinstance(dm_rooms_for_user, (tuple, list))
+
+            dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
+
+        await self._api.account_data_manager.put_global(
+            user_id, AccountDataTypes.DIRECT, dm_map
+        )
+
+    async def _retry_make_join(
+        self, sender: str, target: str, room_id: str, new_membership: str
+    ) -> None:
+        """
+        A function to retry sending the `make_join` request with an increasing backoff. This is
+        implemented to work around a race condition when receiving invites over federation.
+
+        Args:
+            sender: the user performing the membership change
+            target: the user for whom the membership is changing
+            room_id: room id of the room to join
+            new_membership: the type of membership event (in this case will be "join")
+        """
+
+        sleep = 0
+        retries = 0
+        join_event = None
+
+        while retries < 5:
+            try:
+                await self._api.sleep(sleep)
+                join_event = await self._api.update_room_membership(
+                    sender=sender,
+                    target=target,
+                    room_id=room_id,
+                    new_membership=new_membership,
+                )
+            except SynapseError as e:
+                if e.code == HTTPStatus.FORBIDDEN:
+                    logger.debug(
+                        f"update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
+                    )
+                else:
+                    logger.warning(
+                        f"update_room_membership raised an unexpected SynapseError: {e}"
+                    )
+            except Exception as e:
+                logger.warning(
+                    f"update_room_membership raised an unexpected exception: {e}"
+                )
+
+            sleep = 2**retries
+            retries += 1
+
+            if join_event is not None:
+                break
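The retry loop above sleeps 0s, 1s, 2s, 4s and 8s between at most five attempts and swallows per-attempt errors. A self-contained sketch of the same backoff shape (asyncio.sleep in place of the Module API's sleep; attempt is a stand-in for update_room_membership):

import asyncio
import logging
from typing import Awaitable, Callable, Optional, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")

async def retry_with_backoff(
    attempt: Callable[[], Awaitable[T]], max_retries: int = 5
) -> Optional[T]:
    # Same shape as _retry_make_join: sleep 0s, 1s, 2s, 4s, 8s before
    # successive attempts, stop on the first success, give up after
    # max_retries, and log (rather than propagate) per-attempt errors.
    result: Optional[T] = None
    for retries in range(max_retries):
        await asyncio.sleep(0 if retries == 0 else 2 ** (retries - 1))
        try:
            result = await attempt()
        except Exception as e:
            logger.warning("Attempt %d failed: %s", retries + 1, e)
        if result is not None:
            break
    return result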
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 79be7c97c8..e56bdb4072 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -236,6 +236,13 @@ class DeviceMessageHandler:
         local_messages = {}
         remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
+            if not UserID.is_valid(user_id):
+                logger.warning(
+                    "Ignoring attempt to send device message to invalid user: %r",
+                    user_id,
+                )
+                continue
+
             # add an opentracing log entry for each message
             for device_id, message_content in by_device.items():
                 log_kv(
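The guard above drops to-device messages addressed to malformed user IDs before any per-device work happens. A simplified stand-in for UserID.is_valid, kept deliberately rough (the real check enforces more of the Matrix identifier grammar):

import logging

logger = logging.getLogger(__name__)

def looks_like_user_id(user_id: str) -> bool:
    # Very rough shape check only: a Matrix user ID is "@localpart:domain".
    # synapse.types.UserID.is_valid enforces more (sigil, length, grammar).
    return user_id.startswith("@") and ":" in user_id[1:]

messages = {"@alice:example.org": {"DEV1": {}}, "not-a-user-id": {"DEV2": {}}}
for user_id, by_device in messages.items():
    if not looks_like_user_id(user_id):
        logger.warning("Ignoring message to invalid user: %r", user_id)
        continue  # skip, as the handler above does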
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 1ece54ccfc..4f40e9ffd6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -53,6 +53,9 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
+
+
 class E2eKeysHandler:
     def __init__(self, hs: "HomeServer"):
         self.config = hs.config
@@ -62,6 +65,7 @@ class E2eKeysHandler:
         self._appservice_handler = hs.get_application_service_handler()
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
+        self._worker_lock_handler = hs.get_worker_locks_handler()
 
         federation_registry = hs.get_federation_registry()
 
@@ -855,45 +859,53 @@ class E2eKeysHandler:
     async def _upload_one_time_keys_for_user(
         self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
     ) -> None:
-        logger.info(
-            "Adding one_time_keys %r for device %r for user %r at %d",
-            one_time_keys.keys(),
-            device_id,
-            user_id,
-            time_now,
-        )
+        # We take out a lock so that we don't have to worry about a client
+        # sending duplicate requests.
+        lock_key = f"{user_id}_{device_id}"
+        async with self._worker_lock_handler.acquire_lock(
+            ONE_TIME_KEY_UPLOAD, lock_key
+        ):
+            logger.info(
+                "Adding one_time_keys %r for device %r for user %r at %d",
+                one_time_keys.keys(),
+                device_id,
+                user_id,
+                time_now,
+            )
 
-        # make a list of (alg, id, key) tuples
-        key_list = []
-        for key_id, key_obj in one_time_keys.items():
-            algorithm, key_id = key_id.split(":")
-            key_list.append((algorithm, key_id, key_obj))
+            # make a list of (alg, id, key) tuples
+            key_list = []
+            for key_id, key_obj in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((algorithm, key_id, key_obj))
 
-        # First we check if we have already persisted any of the keys.
-        existing_key_map = await self.store.get_e2e_one_time_keys(
-            user_id, device_id, [k_id for _, k_id, _ in key_list]
-        )
+            # First we check if we have already persisted any of the keys.
+            existing_key_map = await self.store.get_e2e_one_time_keys(
+                user_id, device_id, [k_id for _, k_id, _ in key_list]
+            )
 
-        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
-        for algorithm, key_id, key in key_list:
-            ex_json = existing_key_map.get((algorithm, key_id), None)
-            if ex_json:
-                if not _one_time_keys_match(ex_json, key):
-                    raise SynapseError(
-                        400,
-                        (
-                            "One time key %s:%s already exists. "
-                            "Old key: %s; new key: %r"
+            new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
+            for algorithm, key_id, key in key_list:
+                ex_json = existing_key_map.get((algorithm, key_id), None)
+                if ex_json:
+                    if not _one_time_keys_match(ex_json, key):
+                        raise SynapseError(
+                            400,
+                            (
+                                "One time key %s:%s already exists. "
+                                "Old key: %s; new key: %r"
+                            )
+                            % (algorithm, key_id, ex_json, key),
                         )
-                        % (algorithm, key_id, ex_json, key),
+                else:
+                    new_keys.append(
+                        (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
                     )
-            else:
-                new_keys.append(
-                    (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
-                )
 
-        log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
-        await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+            log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
+            await self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, new_keys
+            )
 
     async def upload_signing_keys_for_user(
         self, user_id: str, keys: JsonDict
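The new lock serialises the whole read-check-insert sequence per (user, device), so a client retrying an upload cannot race itself into the "one time key already exists" error. A sketch of the same pattern with an in-process asyncio lock (fetch_existing and persist are stand-ins for the store calls, not real Synapse APIs):

import asyncio
from collections import defaultdict
from typing import Any, Dict, List, Tuple

# One lock per user/device pair, keyed like the worker lock above.
_upload_locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

async def fetch_existing(
    user_id: str, device_id: str, key_ids: List[str]
) -> Dict[Tuple[str, str], Any]:
    return {}  # stand-in for store.get_e2e_one_time_keys

async def persist(
    user_id: str, device_id: str, new_keys: List[Tuple[str, str, Any]]
) -> None:
    pass  # stand-in for store.add_e2e_one_time_keys

async def upload_one_time_keys(
    user_id: str, device_id: str, one_time_keys: Dict[str, Any]
) -> None:
    async with _upload_locks[f"{user_id}_{device_id}"]:
        # "alg:key_id" -> (alg, key_id, key), as in the handler's key_list.
        key_list = []
        for full_id, key in one_time_keys.items():
            algorithm, key_id = full_id.split(":")
            key_list.append((algorithm, key_id, key))
        existing = await fetch_existing(
            user_id, device_id, [k for _, k, _ in key_list]
        )
        # Only insert keys we haven't already persisted.
        new_keys = [t for t in key_list if (t[0], t[1]) not in existing]
        await persist(user_id, device_id, new_keys)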
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e48e70db04..c200e29569 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -590,7 +590,7 @@ class RegistrationHandler:
                 # moving away from bare excepts is a good thing to do.
                 logger.error("Failed to join new user to %r: %r", r, e)
             except Exception as e:
-                logger.error("Failed to join new user to %r: %r", r, e)
+                logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
 
     async def _auto_join_rooms(self, user_id: str) -> None:
         """Automatically joins users to auto join rooms - creating the room in the first place
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c5cee8860b..de092f8623 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -393,9 +393,9 @@ class RelationsHandler:
 
                 # Attempt to find another event to use as the latest event.
                 potential_events, _ = await self._main_store.get_relations_for_event(
+                    room_id,
                     event_id,
                     event,
-                    room_id,
                     RelationTypes.THREAD,
                     direction=Direction.FORWARDS,
                 )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f275d4f35a..ee74289b6c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -817,7 +817,7 @@ class SsoHandler:
                 server_name = profile["avatar_url"].split("/")[-2]
                 media_id = profile["avatar_url"].split("/")[-1]
                 if self._is_mine_server_name(server_name):
-                    media = await self._media_repo.store.get_local_media(media_id)
+                    media = await self._media_repo.store.get_local_media(media_id)  # type: ignore[has-type]
                     if media is not None and upload_name == media.upload_name:
                         logger.info("skipping saving the user avatar")
                         return True
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index fa634d65c7..1d7d9dfdd0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,11 +28,14 @@ from typing import (
     Dict,
     FrozenSet,
     List,
+    Literal,
     Mapping,
     Optional,
     Sequence,
     Set,
     Tuple,
+    Union,
+    overload,
 )
 
 import attr
@@ -128,6 +131,8 @@ class SyncVersion(Enum):
 
     # Traditional `/sync` endpoint
     SYNC_V2 = "sync_v2"
+    # Part of MSC3575 Sliding Sync
+    E2EE_SYNC = "e2ee_sync"
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -297,6 +302,26 @@ class SyncResult:
         )
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class E2eeSyncResult:
+    """
+    Attributes:
+        next_batch: Token for the next sync
+        to_device: List of direct messages for the device.
+        device_lists: List of user_ids whose devices have changed
+        device_one_time_keys_count: Dict of algorithm to count for one time keys
+            for this device
+        device_unused_fallback_key_types: List of key types that have an unused fallback
+            key
+    """
+
+    next_batch: StreamToken
+    to_device: List[JsonDict]
+    device_lists: DeviceListUpdates
+    device_one_time_keys_count: JsonMapping
+    device_unused_fallback_key_types: List[str]
+
+
 class SyncHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs_config = hs.config
@@ -339,6 +364,31 @@ class SyncHandler:
 
         self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
 
+    @overload
+    async def wait_for_sync_for_user(
+        self,
+        requester: Requester,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.SYNC_V2],
+        request_key: SyncRequestKey,
+        since_token: Optional[StreamToken] = None,
+        timeout: int = 0,
+        full_state: bool = False,
+    ) -> SyncResult: ...
+
+    @overload
+    async def wait_for_sync_for_user(
+        self,
+        requester: Requester,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.E2EE_SYNC],
+        request_key: SyncRequestKey,
+        since_token: Optional[StreamToken] = None,
+        timeout: int = 0,
+        full_state: bool = False,
+    ) -> E2eeSyncResult: ...
+
+    @overload
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -348,7 +398,18 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         timeout: int = 0,
         full_state: bool = False,
-    ) -> SyncResult:
+    ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+    async def wait_for_sync_for_user(
+        self,
+        requester: Requester,
+        sync_config: SyncConfig,
+        sync_version: SyncVersion,
+        request_key: SyncRequestKey,
+        since_token: Optional[StreamToken] = None,
+        timeout: int = 0,
+        full_state: bool = False,
+    ) -> Union[SyncResult, E2eeSyncResult]:
         """Get the sync for a client if we have new data for it now. Otherwise
         wait for new data to arrive on the server. If the timeout expires, then
         return an empty sync result.
@@ -361,8 +422,10 @@ class SyncHandler:
             since_token: The point in the stream to sync from.
             timeout: How long to wait for new data to arrive before giving up.
             full_state: Whether to return the full state for each room.
+
         Returns:
             When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+            When `SyncVersion.E2EE_SYNC`, returns an `E2eeSyncResult`.
         """
         # If the user is not part of the mau group, then check that limits have
         # not been exceeded (if not part of the group by this point, almost certain
@@ -383,6 +446,29 @@ class SyncHandler:
         logger.debug("Returning sync response for %s", user_id)
         return res
 
+    @overload
+    async def _wait_for_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.SYNC_V2],
+        since_token: Optional[StreamToken],
+        timeout: int,
+        full_state: bool,
+        cache_context: ResponseCacheContext[SyncRequestKey],
+    ) -> SyncResult: ...
+
+    @overload
+    async def _wait_for_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.E2EE_SYNC],
+        since_token: Optional[StreamToken],
+        timeout: int,
+        full_state: bool,
+        cache_context: ResponseCacheContext[SyncRequestKey],
+    ) -> E2eeSyncResult: ...
+
+    @overload
     async def _wait_for_sync_for_user(
         self,
         sync_config: SyncConfig,
@@ -391,7 +477,17 @@ class SyncHandler:
         timeout: int,
         full_state: bool,
         cache_context: ResponseCacheContext[SyncRequestKey],
-    ) -> SyncResult:
+    ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+    async def _wait_for_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: SyncVersion,
+        since_token: Optional[StreamToken],
+        timeout: int,
+        full_state: bool,
+        cache_context: ResponseCacheContext[SyncRequestKey],
+    ) -> Union[SyncResult, E2eeSyncResult]:
         """The start of the machinery that produces a /sync response.
 
         See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -452,14 +548,16 @@ class SyncHandler:
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result: SyncResult = await self.current_sync_for_user(
-                sync_config, sync_version, since_token, full_state=full_state
+            result: Union[SyncResult, E2eeSyncResult] = (
+                await self.current_sync_for_user(
+                    sync_config, sync_version, since_token, full_state=full_state
+                )
             )
         else:
             # Otherwise, we wait for something to happen and report it to the user.
             async def current_sync_callback(
                 before_token: StreamToken, after_token: StreamToken
-            ) -> SyncResult:
+            ) -> Union[SyncResult, E2eeSyncResult]:
                 return await self.current_sync_for_user(
                     sync_config, sync_version, since_token
                 )
@@ -491,14 +589,43 @@ class SyncHandler:
 
         return result
 
+    @overload
+    async def current_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.SYNC_V2],
+        since_token: Optional[StreamToken] = None,
+        full_state: bool = False,
+    ) -> SyncResult: ...
+
+    @overload
+    async def current_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: Literal[SyncVersion.E2EE_SYNC],
+        since_token: Optional[StreamToken] = None,
+        full_state: bool = False,
+    ) -> E2eeSyncResult: ...
+
+    @overload
     async def current_sync_for_user(
         self,
         sync_config: SyncConfig,
         sync_version: SyncVersion,
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
-    ) -> SyncResult:
-        """Generates the response body of a sync result, represented as a SyncResult.
+    ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+    async def current_sync_for_user(
+        self,
+        sync_config: SyncConfig,
+        sync_version: SyncVersion,
+        since_token: Optional[StreamToken] = None,
+        full_state: bool = False,
+    ) -> Union[SyncResult, E2eeSyncResult]:
+        """
+        Generates the response body of a sync result, represented as a
+        `SyncResult`/`E2eeSyncResult`.
 
         This is a wrapper around `generate_sync_result` which starts an open tracing
         span to track the sync. See `generate_sync_result` for the next part of your
@@ -509,15 +636,25 @@ class SyncHandler:
             sync_version: Determines what kind of sync response to generate.
             since_token: The point in the stream to sync from.
             full_state: Whether to return the full state for each room.
+
         Returns:
             When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+            When `SyncVersion.E2EE_SYNC`, returns an `E2eeSyncResult`.
         """
         with start_active_span("sync.current_sync_for_user"):
             log_kv({"since_token": since_token})
+
             # Go through the `/sync` v2 path
             if sync_version == SyncVersion.SYNC_V2:
-                sync_result: SyncResult = await self.generate_sync_result(
-                    sync_config, since_token, full_state
+                sync_result: Union[SyncResult, E2eeSyncResult] = (
+                    await self.generate_sync_result(
+                        sync_config, since_token, full_state
+                    )
+                )
+            # Go through the MSC3575 Sliding Sync `/sync/e2ee` path
+            elif sync_version == SyncVersion.E2EE_SYNC:
+                sync_result = await self.generate_e2ee_sync_result(
+                    sync_config, since_token
                 )
             else:
                 raise Exception(
@@ -1726,6 +1863,96 @@ class SyncHandler:
             next_batch=sync_result_builder.now_token,
         )
 
+    async def generate_e2ee_sync_result(
+        self,
+        sync_config: SyncConfig,
+        since_token: Optional[StreamToken] = None,
+    ) -> E2eeSyncResult:
+        """
+        Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
+
+        This is represented by an `E2eeSyncResult` struct, which is built from small
+        pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
+        mutable ("inout") parameter to various helper functions. These retrieve and
+        process the data which forms the sync body, often writing to the
+        `sync_result_builder` to store their output.
+
+        At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
+        instance to signify that the sync calculation is complete.
+        """
+        user_id = sync_config.user.to_string()
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service:
+            # We no longer support AS users using /sync directly.
+            # See https://github.com/matrix-org/matrix-doc/issues/1144
+            raise NotImplementedError()
+
+        sync_result_builder = await self.get_sync_result_builder(
+            sync_config,
+            since_token,
+            full_state=False,
+        )
+
+        # 1. Calculate `to_device` events
+        await self._generate_sync_entry_for_to_device(sync_result_builder)
+
+        # 2. Calculate `device_lists`
+        # Device list updates are sent if a since token is provided.
+        device_lists = DeviceListUpdates()
+        include_device_list_updates = bool(since_token and since_token.device_list_key)
+        if include_device_list_updates:
+            # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
+            # is used in calculate_user_changes below.
+            #
+            # TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
+            # figure out the membership changes/derived info needed for
+            # `_generate_sync_entry_for_device_list()`. In the future, we should try to
+            # refactor this away.
+            (
+                newly_joined_rooms,
+                newly_left_rooms,
+            ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
+
+            # This uses the sync_result_builder.joined which is set in
+            # `_generate_sync_entry_for_rooms`, if that didn't find any joined
+            # rooms for some reason it is a no-op.
+            (
+                newly_joined_or_invited_or_knocked_users,
+                newly_left_users,
+            ) = sync_result_builder.calculate_user_changes()
+
+            device_lists = await self._generate_sync_entry_for_device_list(
+                sync_result_builder,
+                newly_joined_rooms=newly_joined_rooms,
+                newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+                newly_left_rooms=newly_left_rooms,
+                newly_left_users=newly_left_users,
+            )
+
+        # 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
+        device_id = sync_config.device_id
+        one_time_keys_count: JsonMapping = {}
+        unused_fallback_key_types: List[str] = []
+        if device_id:
+            # TODO: We should have a way to let clients differentiate between the states of:
+            #   * no change in OTK count since the provided since token
+            #   * the server has zero OTKs left for this device
+            #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+            one_time_keys_count = await self.store.count_e2e_one_time_keys(
+                user_id, device_id
+            )
+            unused_fallback_key_types = list(
+                await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+            )
+
+        return E2eeSyncResult(
+            to_device=sync_result_builder.to_device,
+            device_lists=device_lists,
+            device_one_time_keys_count=one_time_keys_count,
+            device_unused_fallback_key_types=unused_fallback_key_types,
+            next_batch=sync_result_builder.now_token,
+        )
+
     async def get_sync_result_builder(
         self,
         sync_config: SyncConfig,
@@ -1924,7 +2151,7 @@ class SyncHandler:
         users_that_have_changed = (
             await self._device_handler.get_device_changes_in_shared_rooms(
                 user_id,
-                sync_result_builder.joined_room_ids,
+                joined_room_ids,
                 from_token=since_token,
                 now_token=sync_result_builder.now_token,
             )
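Most of the sync.py churn is the typing.overload ladder: callers passing the SYNC_V2 literal are typed as getting a SyncResult, callers passing E2EE_SYNC as getting an E2eeSyncResult, while one runtime implementation serves both (the handler also keeps a catch-all overload for callers holding a plain SyncVersion). A minimal sketch of that pattern:

from enum import Enum
from typing import Literal, Union, overload

class SyncVersion(Enum):
    SYNC_V2 = "sync_v2"
    E2EE_SYNC = "e2ee_sync"

class SyncResult: ...
class E2eeSyncResult: ...

@overload
def sync(version: Literal[SyncVersion.SYNC_V2]) -> SyncResult: ...
@overload
def sync(version: Literal[SyncVersion.E2EE_SYNC]) -> E2eeSyncResult: ...
def sync(version: SyncVersion) -> Union[SyncResult, E2eeSyncResult]:
    # One runtime implementation; the overloads only narrow the return
    # type a type checker reports at each call site.
    if version is SyncVersion.SYNC_V2:
        return SyncResult()
    return E2eeSyncResult()

result: SyncResult = sync(SyncVersion.SYNC_V2)  # type-checks without a cast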
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 0e875132f6..9da8495950 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -650,7 +650,7 @@ class MediaRepository:
 
         file_info = FileInfo(server_name=server_name, file_id=file_id)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
             try:
                 length, headers = await self.client.download_media(
                     server_name,
@@ -693,8 +693,6 @@ class MediaRepository:
                 )
                 raise SynapseError(502, "Failed to fetch remote media")
 
-            await finish()
-
             if b"Content-Type" in headers:
                 media_type = headers[b"Content-Type"][0].decode("ascii")
             else:
@@ -1045,14 +1043,9 @@ class MediaRepository:
                     ),
                 )
 
-                with self.media_storage.store_into_file(file_info) as (
-                    f,
-                    fname,
-                    finish,
-                ):
+                async with self.media_storage.store_into_file(file_info) as (f, fname):
                     try:
                         await self.media_storage.write_to_file(t_byte_source, f)
-                        await finish()
                     finally:
                         t_byte_source.close()
 
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b45b319f5c..9979c48eac 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -27,10 +27,9 @@ from typing import (
     IO,
     TYPE_CHECKING,
     Any,
-    Awaitable,
+    AsyncIterator,
     BinaryIO,
     Callable,
-    Generator,
     Optional,
     Sequence,
     Tuple,
@@ -97,11 +96,9 @@ class MediaStorage:
             the file path written to in the primary media store
         """
 
-        with self.store_into_file(file_info) as (f, fname, finish_cb):
+        async with self.store_into_file(file_info) as (f, fname):
             # Write to the main media repository
             await self.write_to_file(source, f)
-            # Write to the other storage providers
-            await finish_cb()
 
         return fname
 
@@ -111,32 +108,27 @@ class MediaStorage:
         await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
 
     @trace_with_opname("MediaStorage.store_into_file")
-    @contextlib.contextmanager
-    def store_into_file(
+    @contextlib.asynccontextmanager
+    async def store_into_file(
         self, file_info: FileInfo
-    ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
-        """Context manager used to get a file like object to write into, as
+    ) -> AsyncIterator[Tuple[BinaryIO, str]]:
+        """Async Context manager used to get a file like object to write into, as
         described by file_info.
 
-        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
-        like object that can be written to, fname is the absolute path of file
-        on disk, and finish_cb is a function that returns an awaitable.
+        Actually yields a 2-tuple (file, fname), where file is a file-like
+        object that can be written to and fname is the absolute path of the
+        file on disk.
 
         fname can be used to read the contents from after upload, e.g. to
         generate thumbnails.
 
-        finish_cb must be called and waited on after the file has been successfully been
-        written to. Should not be called if there was an error. Checks for spam and
-        stores the file into the configured storage providers.
-
         Args:
             file_info: Info about the file to store
 
         Example:
 
-            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+            async with media_storage.store_into_file(info) as (f, fname):
                 # .. write into f ...
-                await finish_cb()
         """
 
         path = self._file_info_to_path(file_info)
@@ -145,62 +137,42 @@ class MediaStorage:
         dirname = os.path.dirname(fname)
         os.makedirs(dirname, exist_ok=True)
 
-        finished_called = [False]
-
         main_media_repo_write_trace_scope = start_active_span(
             "writing to main media repo"
         )
         main_media_repo_write_trace_scope.__enter__()
 
-        try:
-            with open(fname, "wb") as f:
-
-                async def finish() -> None:
-                    # When someone calls finish, we assume they are done writing to the main media repo
-                    main_media_repo_write_trace_scope.__exit__(None, None, None)
-
-                    with start_active_span("writing to other storage providers"):
-                        # Ensure that all writes have been flushed and close the
-                        # file.
-                        f.flush()
-                        f.close()
-
-                        spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
-                            ReadableFileWrapper(self.clock, fname), file_info
-                        )
-                        if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
-                            logger.info("Blocking media due to spam checker")
-                            # Note that we'll delete the stored media, due to the
-                            # try/except below. The media also won't be stored in
-                            # the DB.
-                            # We currently ignore any additional field returned by
-                            # the spam-check API.
-                            raise SpamMediaException(errcode=spam_check[0])
-
-                        for provider in self.storage_providers:
-                            with start_active_span(str(provider)):
-                                await provider.store_file(path, file_info)
-
-                        finished_called[0] = True
-
-                yield f, fname, finish
-        except Exception as e:
+        with main_media_repo_write_trace_scope:
             try:
-                main_media_repo_write_trace_scope.__exit__(
-                    type(e), None, e.__traceback__
-                )
-                os.remove(fname)
-            except Exception:
-                pass
+                with open(fname, "wb") as f:
+                    yield f, fname
 
-            raise e from None
+            except Exception as e:
+                try:
+                    os.remove(fname)
+                except Exception:
+                    pass
 
-        if not finished_called:
-            exc = Exception("Finished callback not called")
-            main_media_repo_write_trace_scope.__exit__(
-                type(exc), None, exc.__traceback__
+                raise e from None
+
+        with start_active_span("writing to other storage providers"):
+            spam_check = (
+                await self._spam_checker_module_callbacks.check_media_file_for_spam(
+                    ReadableFileWrapper(self.clock, fname), file_info
+                )
             )
-            raise exc
+            if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
+                logger.info("Blocking media due to spam checker")
+                # Note that the media won't be stored in the DB.
+                # We currently ignore any additional field returned by
+                # the spam-check API.
+                raise SpamMediaException(errcode=spam_check[0])
+
+            for provider in self.storage_providers:
+                with start_active_span(str(provider)):
+                    await provider.store_file(path, file_info)
 
     async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index 5538020bec..cc3acf51e1 100644
--- a/synapse/media/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -22,11 +22,27 @@
 import logging
 from io import BytesIO
 from types import TracebackType
-from typing import Optional, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
 
 from PIL import Image
 
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
+from synapse.http.server import respond_with_json
+from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace
+from synapse.media._base import (
+    FileInfo,
+    ThumbnailInfo,
+    respond_404,
+    respond_with_file,
+    respond_with_responder,
+)
+from synapse.media.media_storage import MediaStorage
+
+if TYPE_CHECKING:
+    from synapse.media.media_repository import MediaRepository
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -231,3 +247,471 @@ class Thumbnailer:
     def __del__(self) -> None:
         # Make sure we actually do close the image, rather than leak data.
         self.close()
+
+
+class ThumbnailProvider:
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
+        self.hs = hs
+        self.media_repo = media_repo
+        self.media_storage = media_storage
+        self.store = hs.get_datastores().main
+        self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+
+    async def respond_local_thumbnail(
+        self,
+        request: SynapseRequest,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+        max_timeout_ms: int,
+    ) -> None:
+        media_info = await self.media_repo.get_local_media_info(
+            request, media_id, max_timeout_ms
+        )
+        if not media_info:
+            return
+
+        thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            media_id,
+            url_cache=bool(media_info.url_cache),
+            server_name=None,
+        )
+
+    async def select_or_generate_local_thumbnail(
+        self,
+        request: SynapseRequest,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        max_timeout_ms: int,
+    ) -> None:
+        media_info = await self.media_repo.get_local_media_info(
+            request, media_id, max_timeout_ms
+        )
+
+        if not media_info:
+            return
+
+        thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
+        for info in thumbnail_infos:
+            t_w = info.width == desired_width
+            t_h = info.height == desired_height
+            t_method = info.method == desired_method
+            t_type = info.type == desired_type
+
+            if t_w and t_h and t_method and t_type:
+                file_info = FileInfo(
+                    server_name=None,
+                    file_id=media_id,
+                    url_cache=bool(media_info.url_cache),
+                    thumbnail=info,
+                )
+
+                responder = await self.media_storage.fetch_media(file_info)
+                if responder:
+                    await respond_with_responder(
+                        request, responder, info.type, info.length
+                    )
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
+
+        # Okay, so we generate one.
+        file_path = await self.media_repo.generate_local_exact_thumbnail(
+            media_id,
+            desired_width,
+            desired_height,
+            desired_method,
+            desired_type,
+            url_cache=bool(media_info.url_cache),
+        )
+
+        if file_path:
+            await respond_with_file(request, desired_type, file_path)
+        else:
+            logger.warning("Failed to generate thumbnail")
+            raise SynapseError(400, "Failed to generate thumbnail.")
+
+    async def select_or_generate_remote_thumbnail(
+        self,
+        request: SynapseRequest,
+        server_name: str,
+        media_id: str,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        max_timeout_ms: int,
+    ) -> None:
+        media_info = await self.media_repo.get_remote_media_info(
+            server_name, media_id, max_timeout_ms
+        )
+        if not media_info:
+            respond_404(request)
+            return
+
+        thumbnail_infos = await self.store.get_remote_media_thumbnails(
+            server_name, media_id
+        )
+
+        file_id = media_info.filesystem_id
+
+        for info in thumbnail_infos:
+            t_w = info.width == desired_width
+            t_h = info.height == desired_height
+            t_method = info.method == desired_method
+            t_type = info.type == desired_type
+
+            if t_w and t_h and t_method and t_type:
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,
+                    thumbnail=info,
+                )
+
+                responder = await self.media_storage.fetch_media(file_info)
+                if responder:
+                    await respond_with_responder(
+                        request, responder, info.type, info.length
+                    )
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
+
+        # Okay, so we generate one.
+        file_path = await self.media_repo.generate_remote_exact_thumbnail(
+            server_name,
+            file_id,
+            media_id,
+            desired_width,
+            desired_height,
+            desired_method,
+            desired_type,
+        )
+
+        if file_path:
+            await respond_with_file(request, desired_type, file_path)
+        else:
+            logger.warning("Failed to generate thumbnail")
+            raise SynapseError(400, "Failed to generate thumbnail.")
+
+    async def respond_remote_thumbnail(
+        self,
+        request: SynapseRequest,
+        server_name: str,
+        media_id: str,
+        width: int,
+        height: int,
+        method: str,
+        m_type: str,
+        max_timeout_ms: int,
+    ) -> None:
+        # TODO: Don't download the whole remote file
+        # We should proxy the thumbnail from the remote server instead of
+        # downloading the remote file and generating our own thumbnails.
+        media_info = await self.media_repo.get_remote_media_info(
+            server_name, media_id, max_timeout_ms
+        )
+        if not media_info:
+            return
+
+        thumbnail_infos = await self.store.get_remote_media_thumbnails(
+            server_name, media_id
+        )
+        await self._select_and_respond_with_thumbnail(
+            request,
+            width,
+            height,
+            method,
+            m_type,
+            thumbnail_infos,
+            media_id,
+            media_info.filesystem_id,
+            url_cache=False,
+            server_name=server_name,
+        )
+
+    async def _select_and_respond_with_thumbnail(
+        self,
+        request: SynapseRequest,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[ThumbnailInfo],
+        media_id: str,
+        file_id: str,
+        url_cache: bool,
+        server_name: Optional[str] = None,
+    ) -> None:
+        """
+        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            request: The incoming request.
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of thumbnail info of candidate thumbnails.
+            media_id: The media ID of the requested content.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: True if this is from a URL cache.
+            server_name: The server name, if this is a remote thumbnail.
+        """
+        logger.debug(
+            "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+            media_id,
+            desired_width,
+            desired_height,
+            desired_method,
+            thumbnail_infos,
+        )
+
+        # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+        # different code path to handle it.
+        assert not self.dynamic_thumbnails
+
+        if thumbnail_infos:
+            file_info = self._select_thumbnail(
+                desired_width,
+                desired_height,
+                desired_method,
+                desired_type,
+                thumbnail_infos,
+                file_id,
+                url_cache,
+                server_name,
+            )
+            if not file_info:
+                logger.info("Couldn't find a thumbnail matching the desired inputs")
+                respond_404(request)
+                return
+
+            # The thumbnail property must exist.
+            assert file_info.thumbnail is not None
+
+            responder = await self.media_storage.fetch_media(file_info)
+            if responder:
+                await respond_with_responder(
+                    request,
+                    responder,
+                    file_info.thumbnail.type,
+                    file_info.thumbnail.length,
+                )
+                return
+
+            # If we can't find the thumbnail we regenerate it. This can happen
+            # if e.g. we've deleted the thumbnails but still have the original
+            # image somewhere.
+            #
+            # Since we have an entry for the thumbnail in the DB we a) know we
+            # have successfully generated the thumbnail in the past (so we
+            # don't need to worry about repeatedly failing to generate
+            # thumbnails), and b) have already calculated that appropriate
+            # width/height/method so we can just call the "generate exact"
+            # methods.
+
+            # First let's check that we do actually have the original image
+            # still. This will throw a 404 if we don't.
+            # TODO: We should refetch the thumbnails for remote media.
+            await self.media_storage.ensure_media_is_in_local_cache(
+                FileInfo(server_name, file_id, url_cache=url_cache)
+            )
+
+            if server_name:
+                await self.media_repo.generate_remote_exact_thumbnail(
+                    server_name,
+                    file_id=file_id,
+                    media_id=media_id,
+                    t_width=file_info.thumbnail.width,
+                    t_height=file_info.thumbnail.height,
+                    t_method=file_info.thumbnail.method,
+                    t_type=file_info.thumbnail.type,
+                )
+            else:
+                await self.media_repo.generate_local_exact_thumbnail(
+                    media_id=media_id,
+                    t_width=file_info.thumbnail.width,
+                    t_height=file_info.thumbnail.height,
+                    t_method=file_info.thumbnail.method,
+                    t_type=file_info.thumbnail.type,
+                    url_cache=url_cache,
+                )
+
+            responder = await self.media_storage.fetch_media(file_info)
+            await respond_with_responder(
+                request,
+                responder,
+                file_info.thumbnail.type,
+                file_info.thumbnail.length,
+            )
+        else:
+            # This might be because:
+            # 1. We can't create thumbnails for the given media (corrupted or
+            #    unsupported file type), or
+            # 2. The thumbnailing process never ran or errored out initially
+            #    when the media was first uploaded (these bugs should be
+            #    reported and fixed).
+            # Note that we don't attempt to generate a thumbnail now because
+            # `dynamic_thumbnails` is disabled.
+            logger.info("Failed to find any generated thumbnails")
+
+            assert request.path is not None
+            respond_with_json(
+                request,
+                400,
+                cs_error(
+                    "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+                    % (
+                        request.path.decode(),
+                        ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+                    ),
+                    code=Codes.UNKNOWN,
+                ),
+                send_cors=True,
+            )
+
+    def _select_thumbnail(
+        self,
+        desired_width: int,
+        desired_height: int,
+        desired_method: str,
+        desired_type: str,
+        thumbnail_infos: List[ThumbnailInfo],
+        file_id: str,
+        url_cache: bool,
+        server_name: Optional[str],
+    ) -> Optional[FileInfo]:
+        """
+        Choose an appropriate thumbnail from the previously generated thumbnails.
+
+        Args:
+            desired_width: The desired width, the returned thumbnail may be larger than this.
+            desired_height: The desired height, the returned thumbnail may be larger than this.
+            desired_method: The desired method used to generate the thumbnail.
+            desired_type: The desired content-type of the thumbnail.
+            thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
+            file_id: The ID of the media that a thumbnail is being requested for.
+            url_cache: True if this is from a URL cache.
+            server_name: The server name, if this is a remote thumbnail.
+
+        Returns:
+             The thumbnail which best matches the desired parameters.
+        """
+        desired_method = desired_method.lower()
+
+        # The chosen thumbnail.
+        thumbnail_info = None
+
+        d_w = desired_width
+        d_h = desired_height
+
+        if desired_method == "crop":
+            # Thumbnails that match equal or larger sizes of desired width/height.
+            crop_info_list: List[
+                Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+            ] = []
+            # Other thumbnails.
+            crop_info_list2: List[
+                Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+            ] = []
+            for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info.method != "crop":
+                    continue
+
+                t_w = info.width
+                t_h = info.height
+                aspect_quality = abs(d_w * t_h - d_h * t_w)
+                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info.type
+                length_quality = info.length
+                if t_w >= d_w or t_h >= d_h:
+                    crop_info_list.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
+                        )
+                    )
+                else:
+                    crop_info_list2.append(
+                        (
+                            aspect_quality,
+                            min_quality,
+                            size_quality,
+                            type_quality,
+                            length_quality,
+                            info,
+                        )
+                    )
+            # Pick the most appropriate thumbnail. Some values of `desired_width` and
+            # `desired_height` may result in a tie, in which case we avoid comparing on
+            # the thumbnail info and pick the thumbnail that appears earlier
+            # in the list of candidates.
+            if crop_info_list:
+                thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
+            elif crop_info_list2:
+                thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
+        elif desired_method == "scale":
+            # Thumbnails that match equal or larger sizes of desired width/height.
+            info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
+            # Other thumbnails.
+            info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
+
+            for info in thumbnail_infos:
+                # Skip thumbnails generated with different methods.
+                if info.method != "scale":
+                    continue
+
+                t_w = info.width
+                t_h = info.height
+                size_quality = abs((d_w - t_w) * (d_h - t_h))
+                type_quality = desired_type != info.type
+                length_quality = info.length
+                if t_w >= d_w or t_h >= d_h:
+                    info_list.append((size_quality, type_quality, length_quality, info))
+                else:
+                    info_list2.append(
+                        (size_quality, type_quality, length_quality, info)
+                    )
+            # Pick the most appropriate thumbnail. Some values of `desired_width` and
+            # `desired_height` may result in a tie, in which case we avoid comparing on
+            # the thumbnail info and pick the thumbnail that appears earlier
+            # in the list of candidates.
+            if info_list:
+                thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
+            elif info_list2:
+                thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
+
+        if thumbnail_info:
+            return FileInfo(
+                file_id=file_id,
+                url_cache=url_cache,
+                server_name=server_name,
+                thumbnail=thumbnail_info,
+            )
+
+        # No matching thumbnail was found.
+        return None
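
Note: the selection above ranks each candidate by a tuple of quality scores (smaller is better) and relies on `min` with a key that drops the final element, so ties are broken by list order rather than by comparing `ThumbnailInfo` objects. A minimal, self-contained sketch of the same idea for the "scale" case — the `Candidate` type and sample values are illustrative stand-ins, not Synapse's classes:

    from typing import List, NamedTuple, Optional, Tuple

    class Candidate(NamedTuple):  # stand-in for ThumbnailInfo
        width: int
        height: int
        type: str
        length: int

    def pick_scale_thumbnail(
        d_w: int, d_h: int, desired_type: str, candidates: List[Candidate]
    ) -> Optional[Candidate]:
        scored: List[Tuple[int, bool, int, Candidate]] = []
        for c in candidates:
            size_quality = abs((d_w - c.width) * (d_h - c.height))
            type_quality = desired_type != c.type  # False sorts before True
            scored.append((size_quality, type_quality, c.length, c))
        if not scored:
            return None
        # key=t[:-1] compares only the scores; `min` is stable, so a tie
        # keeps the earliest candidate in the list.
        return min(scored, key=lambda t: t[:-1])[-1]

    # Both candidates score the same on size, so the matching content-type wins:
    print(pick_scale_thumbnail(
        64, 64, "image/png",
        [Candidate(96, 96, "image/jpeg", 900), Candidate(96, 96, "image/png", 1200)],
    ))
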
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 3897823b35..2e65a04789 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -592,7 +592,7 @@ class UrlPreviewer:
 
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
             if url.startswith("data:"):
                 if not allow_data_urls:
                     raise SynapseError(
@@ -603,8 +603,6 @@ class UrlPreviewer:
             else:
                 download_result = await self._download_url(url, f)
 
-            await finish()
-
         try:
             time_now_ms = self.clock.time_msec()
 
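
Note: the old `store_into_file` yielded a `finish` callback that every caller had to remember to await; the new version is an async context manager that finalises the file when the block exits. A hedged sketch of how such a wrapper can be shaped (the real implementation lives in `synapse/media/media_storage.py` and does considerably more, e.g. writing to storage providers):

    import contextlib
    import os
    from typing import AsyncIterator, BinaryIO, Tuple

    @contextlib.asynccontextmanager
    async def store_into_file(path: str) -> AsyncIterator[Tuple[BinaryIO, str]]:
        """Write to a temp file; finalise it when the block exits cleanly."""
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            yield f, tmp_path
        # Reached only if the body didn't raise: the work of the old
        # `await finish()` step (moving the file into place) now lives here.
        os.replace(tmp_path, path)

Call sites shrink accordingly, as the hunk above shows: the `finish` callback disappears from the tuple and exiting the `async with` block performs the finalisation.
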
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
new file mode 100644
index 0000000000..172d240783
--- /dev/null
+++ b/synapse/rest/client/media.py
@@ -0,0 +1,205 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+
+import logging
+import re
+
+from synapse.http.server import (
+    HttpServer,
+    respond_with_json,
+    respond_with_json_bytes,
+    set_corp_headers,
+    set_cors_headers,
+)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.media._base import (
+    DEFAULT_MAX_TIMEOUT_MS,
+    MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
+    respond_404,
+)
+from synapse.media.media_repository import MediaRepository
+from synapse.media.media_storage import MediaStorage
+from synapse.media.thumbnailer import ThumbnailProvider
+from synapse.server import HomeServer
+from synapse.util.stringutils import parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+
+class UnstablePreviewURLServlet(RestServlet):
+    """
+    Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview
+    API for URLs, returning Open Graph (https://ogp.me/) responses (with some
+    Matrix-specific additions).
+
+    This does have trade-offs compared to other designs:
+
+    * Pros:
+      * Simple and flexible; can be used by any clients at any point
+    * Cons:
+      * If each homeserver provides one of these independently, all the homeservers in a
+        room may needlessly DoS the target URI
+      * The URL metadata must be stored somewhere, rather than just using Matrix
+        itself to store the media.
+      * Matrix cannot be used to distribute the metadata between homeservers.
+    """
+
+    PATTERNS = [
+        re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
+    ]
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
+        super().__init__()
+
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.media_repo = media_repo
+        self.media_storage = media_storage
+        assert self.media_repo.url_previewer is not None
+        self.url_previewer = self.media_repo.url_previewer
+
+    async def on_GET(self, request: SynapseRequest) -> None:
+        requester = await self.auth.get_user_by_req(request)
+        url = parse_string(request, "url", required=True)
+        ts = parse_integer(request, "ts")
+        if ts is None:
+            ts = self.clock.time_msec()
+
+        og = await self.url_previewer.preview(url, requester.user, ts)
+        respond_with_json_bytes(request, 200, og, send_cors=True)
+
+
+class UnstableMediaConfigResource(RestServlet):
+    PATTERNS = [
+        re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
+    ]
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        config = hs.config
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.limits_dict = {"m.upload.size": config.media.max_upload_size}
+
+    async def on_GET(self, request: SynapseRequest) -> None:
+        await self.auth.get_user_by_req(request)
+        respond_with_json(request, 200, self.limits_dict, send_cors=True)
+
+
+class UnstableThumbnailResource(RestServlet):
+    PATTERNS = [
+        re.compile(
+            "/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
+        )
+    ]
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        media_repo: "MediaRepository",
+        media_storage: MediaStorage,
+    ):
+        super().__init__()
+
+        self.store = hs.get_datastores().main
+        self.media_repo = media_repo
+        self.media_storage = media_storage
+        self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+        self._is_mine_server_name = hs.is_mine_server_name
+        self._server_name = hs.hostname
+        self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+        self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
+        self.auth = hs.get_auth()
+
+    async def on_GET(
+        self, request: SynapseRequest, server_name: str, media_id: str
+    ) -> None:
+        # Validate the server name, raising if invalid
+        parse_and_validate_server_name(server_name)
+        await self.auth.get_user_by_req(request)
+
+        set_cors_headers(request)
+        set_corp_headers(request)
+        width = parse_integer(request, "width", required=True)
+        height = parse_integer(request, "height", required=True)
+        method = parse_string(request, "method", "scale")
+        # TODO Parse the Accept header to get a prioritised list of thumbnail types.
+        m_type = "image/png"
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+        )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
+        if self._is_mine_server_name(server_name):
+            if self.dynamic_thumbnails:
+                await self.thumbnailer.select_or_generate_local_thumbnail(
+                    request, media_id, width, height, method, m_type, max_timeout_ms
+                )
+            else:
+                await self.thumbnailer.respond_local_thumbnail(
+                    request, media_id, width, height, method, m_type, max_timeout_ms
+                )
+            self.media_repo.mark_recently_accessed(None, media_id)
+        else:
+            # Don't let users download media from configured domains, even if it
+            # is already downloaded. This is Trust & Safety tooling to make some
+            # media inaccessible to local users.
+            # See `prevent_media_downloads_from` config docs for more info.
+            if server_name in self.prevent_media_downloads_from:
+                respond_404(request)
+                return
+
+            remote_resp_function = (
+                self.thumbnailer.select_or_generate_remote_thumbnail
+                if self.dynamic_thumbnails
+                else self.thumbnailer.respond_remote_thumbnail
+            )
+            await remote_resp_function(
+                request,
+                server_name,
+                media_id,
+                width,
+                height,
+                method,
+                m_type,
+                max_timeout_ms,
+            )
+            self.media_repo.mark_recently_accessed(server_name, media_id)
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+    if hs.config.experimental.msc3916_authenticated_media_enabled:
+        media_repo = hs.get_media_repository()
+        if hs.config.media.url_preview_enabled:
+            UnstablePreviewURLServlet(
+                hs, media_repo, media_repo.media_storage
+            ).register(http_server)
+        UnstableMediaConfigResource(hs).register(http_server)
+        UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+            http_server
+        )
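
Note: unlike the legacy `/_matrix/media` endpoints, these servlets call `get_user_by_req`, so requests must be authenticated. As a usage sketch, a client on a server with `msc3916_authenticated_media_enabled` could fetch a thumbnail like this (the homeserver URL, media ID, and access token are placeholders; any HTTP client works in place of `requests`):

    import requests

    HOMESERVER = "https://matrix.example.org"  # placeholder
    ACCESS_TOKEN = "syt_..."  # placeholder

    resp = requests.get(
        f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3916/media"
        f"/thumbnail/matrix.example.org/someMediaId",
        params={"width": 96, "height": 96, "method": "scale"},
        headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    )
    resp.raise_for_status()
    with open("thumb.png", "wb") as f:
        f.write(resp.content)
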
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 4a57eaf930..27ea943e31 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
         return result
 
 
+class SlidingSyncE2eeRestServlet(RestServlet):
+    """
+    API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
+    of Sliding Sync but doesn't have any sliding window component. It's just a way to
+    get E2EE events without having to sit through a big initial sync (`/sync` v2), and
+    it stops encryption events from being held up behind the main sync response.
+
+    Having To-Device messages split out to this sync endpoint also helps when clients
+    need to have two or more sync streams open at a time, e.g. a push notification
+    process and a main process. Otherwise the two processes race to fetch the To-Device
+    events, which requires complex synchronisation rules to ensure the token is
+    correctly and atomically exchanged between processes.
+
+    GET parameters::
+        timeout(int): How long to wait for new events in milliseconds.
+        since(batch_token): Batch token when asking for incremental deltas.
+
+    Response JSON::
+        {
+            "next_batch": // batch token for the next /sync
+            "to_device": {
+                // list of to-device events
+                "events": [
+                    {
+                        "content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
+                        "type": "m.room.encrypted",
+                        "sender": "@alice:example.com",
+                    }
+                    // ...
+                ]
+            },
+            "device_lists": {
+                "changed": ["@alice:example.com"],
+                "left": ["@bob:example.com"]
+            },
+            "device_one_time_keys_count": {
+                "signed_curve25519": 50
+            },
+            "device_unused_fallback_key_types": [
+                "signed_curve25519"
+            ]
+        }
+    """
+
+    PATTERNS = client_patterns(
+        "/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self.sync_handler = hs.get_sync_handler()
+
+        # Filtering only matters for the `device_lists`, because it requires derived
+        # information from rooms (see how `_generate_sync_entry_for_rooms()` prepares
+        # the data for `_generate_sync_entry_for_device_list()`).
+        self.only_member_events_filter_collection = FilterCollection(
+            self.hs,
+            {
+                "room": {
+                    # We only care about membership events for the `device_lists`.
+                    # Membership will tell us whether a user has joined/left a room and
+                    # if there are new devices to encrypt for.
+                    "timeline": {
+                        "types": ["m.room.member"],
+                    },
+                    "state": {
+                        "types": ["m.room.member"],
+                    },
+                    # We don't want any extra account_data generated because it's not
+                    # returned by this endpoint. This helps us avoid work in
+                    # `_generate_sync_entry_for_rooms()`
+                    "account_data": {
+                        "not_types": ["*"],
+                    },
+                    # We don't want any extra ephemeral data generated because it's not
+                    # returned by this endpoint. This helps us avoid work in
+                    # `_generate_sync_entry_for_rooms()`
+                    "ephemeral": {
+                        "not_types": ["*"],
+                    },
+                },
+                # We don't want any extra account_data generated because it's not
+                # returned by this endpoint. (This is just here for good measure)
+                "account_data": {
+                    "not_types": ["*"],
+                },
+                # We don't want any extra presence data generated because it's not
+                # returned by this endpoint. (This is just here for good measure)
+                "presence": {
+                    "not_types": ["*"],
+                },
+            },
+        )
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        user = requester.user
+        device_id = requester.device_id
+
+        timeout = parse_integer(request, "timeout", default=0)
+        since = parse_string(request, "since")
+
+        sync_config = SyncConfig(
+            user=user,
+            filter_collection=self.only_member_events_filter_collection,
+            is_guest=requester.is_guest,
+            device_id=device_id,
+        )
+
+        since_token = None
+        if since is not None:
+            since_token = await StreamToken.from_string(self.store, since)
+
+        # Request cache key
+        request_key = (
+            SyncVersion.E2EE_SYNC,
+            user,
+            timeout,
+            since,
+        )
+
+        # Gather data for the response
+        sync_result = await self.sync_handler.wait_for_sync_for_user(
+            requester,
+            sync_config,
+            SyncVersion.E2EE_SYNC,
+            request_key,
+            since_token=since_token,
+            timeout=timeout,
+            full_state=False,
+        )
+
+        # The client may have disconnected by now; don't bother to serialize the
+        # response if so.
+        if request._disconnected:
+            logger.info("Client has disconnected; not serializing response.")
+            return 200, {}
+
+        response: JsonDict = defaultdict(dict)
+        response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+        if sync_result.to_device:
+            response["to_device"] = {"events": sync_result.to_device}
+
+        if sync_result.device_lists.changed:
+            response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+        if sync_result.device_lists.left:
+            response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+        # We always include this because of https://github.com/vector-im/element-android/issues/3725.
+        # The spec isn't terribly clear on when this can be omitted and how a client would tell
+        # the difference between "no keys present" and "nothing changed", i.e. between the whole
+        # field being absent and an individual key type entry being absent.
+        # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
+        response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
+
+        # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+        # states that this field should always be included, as long as the server supports the feature.
+        response["device_unused_fallback_key_types"] = (
+            sync_result.device_unused_fallback_key_types
+        )
+
+        return 200, response
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)
+
+    if hs.config.experimental.msc3575_enabled:
+        SlidingSyncE2eeRestServlet(hs).register(http_server)
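
Note: to illustrate the intended client flow against this endpoint, here is a hedged polling sketch (URL and field names per the servlet's docstring above; the homeserver URL and token are placeholders, and token handling is simplified):

    import requests

    HOMESERVER = "https://matrix.example.org"  # placeholder
    ACCESS_TOKEN = "syt_..."  # placeholder

    def e2ee_sync_loop() -> None:
        since = None
        while True:
            params = {"timeout": 30000}
            if since is not None:
                params["since"] = since
            resp = requests.get(
                f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
                params=params,
                headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
            )
            resp.raise_for_status()
            body = resp.json()
            for event in body.get("to_device", {}).get("events", []):
                print("to-device event from", event.get("sender"))
            # A real client should persist `next_batch` atomically with the
            # processed events, so messages aren't lost or double-processed.
            since = body["next_batch"]
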
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 7cb335c7c3..fe8fbb06e4 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -22,23 +22,18 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING
 
-from synapse.api.errors import Codes, SynapseError, cs_error
-from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
-from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
+from synapse.http.server import set_corp_headers, set_cors_headers
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.media._base import (
     DEFAULT_MAX_TIMEOUT_MS,
     MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
-    FileInfo,
-    ThumbnailInfo,
     respond_404,
-    respond_with_file,
-    respond_with_responder,
 )
 from synapse.media.media_storage import MediaStorage
+from synapse.media.thumbnailer import ThumbnailProvider
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
         self.store = hs.get_datastores().main
         self.media_repo = media_repo
         self.media_storage = media_storage
-        self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
         self._is_mine_server_name = hs.is_mine_server_name
         self._server_name = hs.hostname
         self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+        self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+        self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
 
     async def on_GET(
         self, request: SynapseRequest, server_name: str, media_id: str
@@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
 
         if self._is_mine_server_name(server_name):
             if self.dynamic_thumbnails:
-                await self._select_or_generate_local_thumbnail(
+                await self.thumbnail_provider.select_or_generate_local_thumbnail(
                     request, media_id, width, height, method, m_type, max_timeout_ms
                 )
             else:
-                await self._respond_local_thumbnail(
+                await self.thumbnail_provider.respond_local_thumbnail(
                     request, media_id, width, height, method, m_type, max_timeout_ms
                 )
             self.media_repo.mark_recently_accessed(None, media_id)
@@ -109,9 +105,9 @@ class ThumbnailResource(RestServlet):
                 return
 
             remote_resp_function = (
-                self._select_or_generate_remote_thumbnail
+                self.thumbnail_provider.select_or_generate_remote_thumbnail
                 if self.dynamic_thumbnails
-                else self._respond_remote_thumbnail
+                else self.thumbnail_provider.respond_remote_thumbnail
             )
             await remote_resp_function(
                 request,
@@ -124,457 +120,3 @@ class ThumbnailResource(RestServlet):
                 max_timeout_ms,
             )
             self.media_repo.mark_recently_accessed(server_name, media_id)
-
-    async def _respond_local_thumbnail(
-        self,
-        request: SynapseRequest,
-        media_id: str,
-        width: int,
-        height: int,
-        method: str,
-        m_type: str,
-        max_timeout_ms: int,
-    ) -> None:
-        media_info = await self.media_repo.get_local_media_info(
-            request, media_id, max_timeout_ms
-        )
-        if not media_info:
-            return
-
-        thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-        await self._select_and_respond_with_thumbnail(
-            request,
-            width,
-            height,
-            method,
-            m_type,
-            thumbnail_infos,
-            media_id,
-            media_id,
-            url_cache=bool(media_info.url_cache),
-            server_name=None,
-        )
-
-    async def _select_or_generate_local_thumbnail(
-        self,
-        request: SynapseRequest,
-        media_id: str,
-        desired_width: int,
-        desired_height: int,
-        desired_method: str,
-        desired_type: str,
-        max_timeout_ms: int,
-    ) -> None:
-        media_info = await self.media_repo.get_local_media_info(
-            request, media_id, max_timeout_ms
-        )
-
-        if not media_info:
-            return
-
-        thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-        for info in thumbnail_infos:
-            t_w = info.width == desired_width
-            t_h = info.height == desired_height
-            t_method = info.method == desired_method
-            t_type = info.type == desired_type
-
-            if t_w and t_h and t_method and t_type:
-                file_info = FileInfo(
-                    server_name=None,
-                    file_id=media_id,
-                    url_cache=bool(media_info.url_cache),
-                    thumbnail=info,
-                )
-
-                responder = await self.media_storage.fetch_media(file_info)
-                if responder:
-                    await respond_with_responder(
-                        request, responder, info.type, info.length
-                    )
-                    return
-
-        logger.debug("We don't have a thumbnail of that size. Generating")
-
-        # Okay, so we generate one.
-        file_path = await self.media_repo.generate_local_exact_thumbnail(
-            media_id,
-            desired_width,
-            desired_height,
-            desired_method,
-            desired_type,
-            url_cache=bool(media_info.url_cache),
-        )
-
-        if file_path:
-            await respond_with_file(request, desired_type, file_path)
-        else:
-            logger.warning("Failed to generate thumbnail")
-            raise SynapseError(400, "Failed to generate thumbnail.")
-
-    async def _select_or_generate_remote_thumbnail(
-        self,
-        request: SynapseRequest,
-        server_name: str,
-        media_id: str,
-        desired_width: int,
-        desired_height: int,
-        desired_method: str,
-        desired_type: str,
-        max_timeout_ms: int,
-    ) -> None:
-        media_info = await self.media_repo.get_remote_media_info(
-            server_name, media_id, max_timeout_ms
-        )
-        if not media_info:
-            respond_404(request)
-            return
-
-        thumbnail_infos = await self.store.get_remote_media_thumbnails(
-            server_name, media_id
-        )
-
-        file_id = media_info.filesystem_id
-
-        for info in thumbnail_infos:
-            t_w = info.width == desired_width
-            t_h = info.height == desired_height
-            t_method = info.method == desired_method
-            t_type = info.type == desired_type
-
-            if t_w and t_h and t_method and t_type:
-                file_info = FileInfo(
-                    server_name=server_name,
-                    file_id=file_id,
-                    thumbnail=info,
-                )
-
-                responder = await self.media_storage.fetch_media(file_info)
-                if responder:
-                    await respond_with_responder(
-                        request, responder, info.type, info.length
-                    )
-                    return
-
-        logger.debug("We don't have a thumbnail of that size. Generating")
-
-        # Okay, so we generate one.
-        file_path = await self.media_repo.generate_remote_exact_thumbnail(
-            server_name,
-            file_id,
-            media_id,
-            desired_width,
-            desired_height,
-            desired_method,
-            desired_type,
-        )
-
-        if file_path:
-            await respond_with_file(request, desired_type, file_path)
-        else:
-            logger.warning("Failed to generate thumbnail")
-            raise SynapseError(400, "Failed to generate thumbnail.")
-
-    async def _respond_remote_thumbnail(
-        self,
-        request: SynapseRequest,
-        server_name: str,
-        media_id: str,
-        width: int,
-        height: int,
-        method: str,
-        m_type: str,
-        max_timeout_ms: int,
-    ) -> None:
-        # TODO: Don't download the whole remote file
-        # We should proxy the thumbnail from the remote server instead of
-        # downloading the remote file and generating our own thumbnails.
-        media_info = await self.media_repo.get_remote_media_info(
-            server_name, media_id, max_timeout_ms
-        )
-        if not media_info:
-            return
-
-        thumbnail_infos = await self.store.get_remote_media_thumbnails(
-            server_name, media_id
-        )
-        await self._select_and_respond_with_thumbnail(
-            request,
-            width,
-            height,
-            method,
-            m_type,
-            thumbnail_infos,
-            media_id,
-            media_info.filesystem_id,
-            url_cache=False,
-            server_name=server_name,
-        )
-
-    async def _select_and_respond_with_thumbnail(
-        self,
-        request: SynapseRequest,
-        desired_width: int,
-        desired_height: int,
-        desired_method: str,
-        desired_type: str,
-        thumbnail_infos: List[ThumbnailInfo],
-        media_id: str,
-        file_id: str,
-        url_cache: bool,
-        server_name: Optional[str] = None,
-    ) -> None:
-        """
-        Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
-
-        Args:
-            request: The incoming request.
-            desired_width: The desired width, the returned thumbnail may be larger than this.
-            desired_height: The desired height, the returned thumbnail may be larger than this.
-            desired_method: The desired method used to generate the thumbnail.
-            desired_type: The desired content-type of the thumbnail.
-            thumbnail_infos: A list of thumbnail info of candidate thumbnails.
-            file_id: The ID of the media that a thumbnail is being requested for.
-            url_cache: True if this is from a URL cache.
-            server_name: The server name, if this is a remote thumbnail.
-        """
-        logger.debug(
-            "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
-            media_id,
-            desired_width,
-            desired_height,
-            desired_method,
-            thumbnail_infos,
-        )
-
-        # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
-        # different code path to handle it.
-        assert not self.dynamic_thumbnails
-
-        if thumbnail_infos:
-            file_info = self._select_thumbnail(
-                desired_width,
-                desired_height,
-                desired_method,
-                desired_type,
-                thumbnail_infos,
-                file_id,
-                url_cache,
-                server_name,
-            )
-            if not file_info:
-                logger.info("Couldn't find a thumbnail matching the desired inputs")
-                respond_404(request)
-                return
-
-            # The thumbnail property must exist.
-            assert file_info.thumbnail is not None
-
-            responder = await self.media_storage.fetch_media(file_info)
-            if responder:
-                await respond_with_responder(
-                    request,
-                    responder,
-                    file_info.thumbnail.type,
-                    file_info.thumbnail.length,
-                )
-                return
-
-            # If we can't find the thumbnail we regenerate it. This can happen
-            # if e.g. we've deleted the thumbnails but still have the original
-            # image somewhere.
-            #
-            # Since we have an entry for the thumbnail in the DB we a) know we
-            # have have successfully generated the thumbnail in the past (so we
-            # don't need to worry about repeatedly failing to generate
-            # thumbnails), and b) have already calculated that appropriate
-            # width/height/method so we can just call the "generate exact"
-            # methods.
-
-            # First let's check that we do actually have the original image
-            # still. This will throw a 404 if we don't.
-            # TODO: We should refetch the thumbnails for remote media.
-            await self.media_storage.ensure_media_is_in_local_cache(
-                FileInfo(server_name, file_id, url_cache=url_cache)
-            )
-
-            if server_name:
-                await self.media_repo.generate_remote_exact_thumbnail(
-                    server_name,
-                    file_id=file_id,
-                    media_id=media_id,
-                    t_width=file_info.thumbnail.width,
-                    t_height=file_info.thumbnail.height,
-                    t_method=file_info.thumbnail.method,
-                    t_type=file_info.thumbnail.type,
-                )
-            else:
-                await self.media_repo.generate_local_exact_thumbnail(
-                    media_id=media_id,
-                    t_width=file_info.thumbnail.width,
-                    t_height=file_info.thumbnail.height,
-                    t_method=file_info.thumbnail.method,
-                    t_type=file_info.thumbnail.type,
-                    url_cache=url_cache,
-                )
-
-            responder = await self.media_storage.fetch_media(file_info)
-            await respond_with_responder(
-                request,
-                responder,
-                file_info.thumbnail.type,
-                file_info.thumbnail.length,
-            )
-        else:
-            # This might be because:
-            # 1. We can't create thumbnails for the given media (corrupted or
-            #    unsupported file type), or
-            # 2. The thumbnailing process never ran or errored out initially
-            #    when the media was first uploaded (these bugs should be
-            #    reported and fixed).
-            # Note that we don't attempt to generate a thumbnail now because
-            # `dynamic_thumbnails` is disabled.
-            logger.info("Failed to find any generated thumbnails")
-
-            assert request.path is not None
-            respond_with_json(
-                request,
-                400,
-                cs_error(
-                    "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
-                    % (
-                        request.path.decode(),
-                        ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
-                    ),
-                    code=Codes.UNKNOWN,
-                ),
-                send_cors=True,
-            )
-
-    def _select_thumbnail(
-        self,
-        desired_width: int,
-        desired_height: int,
-        desired_method: str,
-        desired_type: str,
-        thumbnail_infos: List[ThumbnailInfo],
-        file_id: str,
-        url_cache: bool,
-        server_name: Optional[str],
-    ) -> Optional[FileInfo]:
-        """
-        Choose an appropriate thumbnail from the previously generated thumbnails.
-
-        Args:
-            desired_width: The desired width, the returned thumbnail may be larger than this.
-            desired_height: The desired height, the returned thumbnail may be larger than this.
-            desired_method: The desired method used to generate the thumbnail.
-            desired_type: The desired content-type of the thumbnail.
-            thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
-            file_id: The ID of the media that a thumbnail is being requested for.
-            url_cache: True if this is from a URL cache.
-            server_name: The server name, if this is a remote thumbnail.
-
-        Returns:
-             The thumbnail which best matches the desired parameters.
-        """
-        desired_method = desired_method.lower()
-
-        # The chosen thumbnail.
-        thumbnail_info = None
-
-        d_w = desired_width
-        d_h = desired_height
-
-        if desired_method == "crop":
-            # Thumbnails that match equal or larger sizes of desired width/height.
-            crop_info_list: List[
-                Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
-            ] = []
-            # Other thumbnails.
-            crop_info_list2: List[
-                Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
-            ] = []
-            for info in thumbnail_infos:
-                # Skip thumbnails generated with different methods.
-                if info.method != "crop":
-                    continue
-
-                t_w = info.width
-                t_h = info.height
-                aspect_quality = abs(d_w * t_h - d_h * t_w)
-                min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
-                size_quality = abs((d_w - t_w) * (d_h - t_h))
-                type_quality = desired_type != info.type
-                length_quality = info.length
-                if t_w >= d_w or t_h >= d_h:
-                    crop_info_list.append(
-                        (
-                            aspect_quality,
-                            min_quality,
-                            size_quality,
-                            type_quality,
-                            length_quality,
-                            info,
-                        )
-                    )
-                else:
-                    crop_info_list2.append(
-                        (
-                            aspect_quality,
-                            min_quality,
-                            size_quality,
-                            type_quality,
-                            length_quality,
-                            info,
-                        )
-                    )
-            # Pick the most appropriate thumbnail. Some values of `desired_width` and
-            # `desired_height` may result in a tie, in which case we avoid comparing on
-            # the thumbnail info and pick the thumbnail that appears earlier
-            # in the list of candidates.
-            if crop_info_list:
-                thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
-            elif crop_info_list2:
-                thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
-        elif desired_method == "scale":
-            # Thumbnails that match equal or larger sizes of desired width/height.
-            info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
-            # Other thumbnails.
-            info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
-
-            for info in thumbnail_infos:
-                # Skip thumbnails generated with different methods.
-                if info.method != "scale":
-                    continue
-
-                t_w = info.width
-                t_h = info.height
-                size_quality = abs((d_w - t_w) * (d_h - t_h))
-                type_quality = desired_type != info.type
-                length_quality = info.length
-                if t_w >= d_w or t_h >= d_h:
-                    info_list.append((size_quality, type_quality, length_quality, info))
-                else:
-                    info_list2.append(
-                        (size_quality, type_quality, length_quality, info)
-                    )
-            # Pick the most appropriate thumbnail. Some values of `desired_width` and
-            # `desired_height` may result in a tie, in which case we avoid comparing on
-            # the thumbnail info and pick the thumbnail that appears earlier
-            # in the list of candidates.
-            if info_list:
-                thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
-            elif info_list2:
-                thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
-
-        if thumbnail_info:
-            return FileInfo(
-                file_id=file_id,
-                url_cache=url_cache,
-                server_name=server_name,
-                thumbnail=thumbnail_info,
-            )
-
-        # No matching thumbnail was found.
-        return None
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d9c85e411e..569f618193 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,7 +2461,11 @@ class DatabasePool:
 
 
 def make_in_list_sql_clause(
-    database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
+    database_engine: BaseDatabaseEngine,
+    column: str,
+    iterable: Collection[Any],
+    *,
+    negative: bool = False,
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.
 
@@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
         database_engine
         column: Name of the column
         iterable: The values to check the column against.
+        negative: Whether we should check for inequality, i.e. `NOT IN`
 
     Returns:
         A tuple of SQL query and the args
@@ -2482,9 +2487,19 @@ def make_in_list_sql_clause(
     if database_engine.supports_using_any_list:
         # This should hopefully be faster, but also makes postgres query
         # stats easier to understand.
-        return "%s = ANY(?)" % (column,), [list(iterable)]
+        if not negative:
+            clause = f"{column} = ANY(?)"
+        else:
+            clause = f"{column} != ALL(?)"
+
+        return clause, [list(iterable)]
     else:
-        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+        params = ",".join("?" for _ in iterable)
+        if not negative:
+            clause = f"{column} IN ({params})"
+        else:
+            clause = f"{column} NOT IN ({params})"
+        return clause, list(iterable)
 
 
 # These overloads ensure that `columns` and `iterable` values have the same length.
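
Note: a small sketch showing what the new `negative` flag produces on each engine (the toy function mirrors the logic above rather than importing Synapse; `supports_any` plays the role of `supports_using_any_list`):

    from typing import Any, Collection, List, Tuple

    def make_in_list_sql_clause(
        supports_any: bool,
        column: str,
        iterable: Collection[Any],
        *,
        negative: bool = False,
    ) -> Tuple[str, List[Any]]:
        if supports_any:  # PostgreSQL: one array-valued bind parameter
            clause = f"{column} != ALL(?)" if negative else f"{column} = ANY(?)"
            return clause, [list(iterable)]
        # SQLite fallback: one placeholder per value
        params = ",".join("?" for _ in iterable)
        clause = f"{column} NOT IN ({params})" if negative else f"{column} IN ({params})"
        return clause, list(iterable)

    print(make_in_list_sql_clause(True, "room_id", ["!a", "!b"], negative=True))
    # ('room_id != ALL(?)', [['!a', '!b']])
    print(make_in_list_sql_clause(False, "room_id", ["!a", "!b"], negative=True))
    # ('room_id NOT IN (?,?)', ['!a', '!b'])

Keeping both shapes behind one helper means callers get the `= ANY(?)` / `!= ALL(?)` form on PostgreSQL — a single bind parameter, which also keeps query stats readable — without duplicating the SQLite fallback at every call site.
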
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 563450a97e..9611a84932 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,11 +43,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_encoder
@@ -75,37 +73,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
         self._account_data_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._account_data_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="account_data",
-                instance_name=self._instance_name,
-                tables=[
-                    ("room_account_data", "instance_name", "stream_id"),
-                    ("room_tags_revisions", "instance_name", "stream_id"),
-                    ("account_data", "instance_name", "stream_id"),
-                ],
-                sequence_name="account_data_sequence",
-                writers=hs.config.worker.writers.account_data,
-            )
-        else:
-            # Multiple writers are not supported for SQLite.
-            #
-            # We shouldn't be running in worker mode with SQLite, but its useful
-            # to support it for unit tests.
-            self._account_data_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "room_account_data",
-                "stream_id",
-                extra_tables=[
-                    ("account_data", "stream_id"),
-                    ("room_tags_revisions", "stream_id"),
-                ],
-                is_writer=self._instance_name in hs.config.worker.writers.account_data,
-            )
+        self._account_data_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="account_data",
+            instance_name=self._instance_name,
+            tables=[
+                ("room_account_data", "instance_name", "stream_id"),
+                ("room_tags_revisions", "instance_name", "stream_id"),
+                ("account_data", "instance_name", "stream_id"),
+            ],
+            sequence_name="account_data_sequence",
+            writers=hs.config.worker.writers.account_data,
+        )
 
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
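
Note: this migration changes only the generator's construction; call sites keep allocating stream IDs through the same interface. A hedged sketch of the typical usage pattern, loosely modelled on `add_account_data_for_user` (the real method also invalidates caches and wakes the notifier, and JSON-encodes `content` before writing):

    # Sketch; `store` stands in for the AccountDataWorkerStore instance.
    async def add_account_data(store, user_id: str, data_type: str, content: str) -> int:
        # get_next() reserves a stream ID and, on exiting the context,
        # marks it as persisted so readers can advance past it.
        async with store._account_data_id_gen.get_next() as next_id:
            await store.db_pool.simple_upsert(
                desc="add_account_data",
                table="account_data",
                keyvalues={"user_id": user_id, "account_data_type": data_type},
                values={"stream_id": next_id, "content": content},
            )
        return store._account_data_id_gen.get_current_token()
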
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index bfd492d95d..c6787faea0 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._invalidate_local_get_event_cache(redacts)  # type: ignore[attr-defined]
             # Caches which might leak edits must be invalidated for the event being
             # redacted.
-            self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
+            self._attempt_to_invalidate_cache(
+                "get_relations_for_event",
+                (
+                    room_id,
+                    redacts,
+                ),
+            )
             self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
             self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
             self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
@@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             )
 
         if relates_to:
-            self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+            self._attempt_to_invalidate_cache(
+                "get_relations_for_event",
+                (
+                    room_id,
+                    relates_to,
+                ),
+            )
             self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
             self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
@@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache(
             "get_unread_event_push_actions_by_room_for_user", (room_id,)
         )
+        self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
 
         self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
-        self._attempt_to_invalidate_cache("get_relations_for_event", None)
         self._attempt_to_invalidate_cache("get_applicable_edit", None)
         self._attempt_to_invalidate_cache("get_thread_id", None)
         self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 008b925fb2..84c4a5d324 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,11 +50,9 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -89,35 +87,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-        if isinstance(database.engine, PostgresEngine):
-            self._can_write_to_device = (
-                self._instance_name in hs.config.worker.writers.to_device
-            )
+        self._can_write_to_device = (
+            self._instance_name in hs.config.worker.writers.to_device
+        )
 
-            self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
-                MultiWriterIdGenerator(
-                    db_conn=db_conn,
-                    db=database,
-                    notifier=hs.get_replication_notifier(),
-                    stream_name="to_device",
-                    instance_name=self._instance_name,
-                    tables=[
-                        ("device_inbox", "instance_name", "stream_id"),
-                        ("device_federation_outbox", "instance_name", "stream_id"),
-                    ],
-                    sequence_name="device_inbox_sequence",
-                    writers=hs.config.worker.writers.to_device,
-                )
-            )
-        else:
-            self._can_write_to_device = True
-            self._to_device_msg_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "device_inbox",
-                "stream_id",
-                extra_tables=[("device_federation_outbox", "stream_id")],
-            )
+        self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="to_device",
+            instance_name=self._instance_name,
+            tables=[
+                ("device_inbox", "instance_name", "stream_id"),
+                ("device_federation_outbox", "instance_name", "stream_id"),
+            ],
+            sequence_name="device_inbox_sequence",
+            writers=hs.config.worker.writers.to_device,
+        )
 
         max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 4f723d8da1..7f6d1ab45c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -2176,7 +2176,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         sql = """
             UPDATE device_lists_changes_in_room
             SET converted_to_destinations = true
-            WHERE stream_id > ? AND user_id = ? and device_id = ? AND room_id = ?;
+            WHERE stream_id > ? AND user_id = ? AND device_id = ?
+                AND room_id = ? AND NOT converted_to_destinations
         """
 
         def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 990698aa5c..fd7167904d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1923,7 +1923,12 @@ class PersistEventsStore:
 
         # Any relation information for the related event must be cleared.
         self.store._invalidate_cache_and_stream(
-            txn, self.store.get_relations_for_event, (redacted_relates_to,)
+            txn,
+            self.store.get_relations_for_event,
+            (
+                room_id,
+                redacted_relates_to,
+            ),
         )
         if rel_type == RelationTypes.REFERENCE:
             self.store._invalidate_cache_and_stream(
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 6c979f9f2c..64d303e330 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             results = list(txn)
             # (event_id, parent_id, rel_type) for each relation
-            relations_to_insert: List[Tuple[str, str, str]] = []
+            relations_to_insert: List[Tuple[str, str, str, str]] = []
             for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
@@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                 if not isinstance(parent_id, str):
                     continue
 
-                relations_to_insert.append((event_id, parent_id, rel_type))
+                room_id = event_json["room_id"]
+                relations_to_insert.append((room_id, event_id, parent_id, rel_type))
 
             # Insert the missing data, note that we upsert here in case the event
             # has already been processed.
@@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                     txn=txn,
                     table="event_relations",
                     key_names=("event_id",),
-                    key_values=[(r[0],) for r in relations_to_insert],
+                    key_values=[(r[1],) for r in relations_to_insert],
                     value_names=("relates_to_id", "relation_type"),
-                    value_values=[r[1:] for r in relations_to_insert],
+                    value_values=[r[2:] for r in relations_to_insert],
                 )
 
                 # Iterate the parent IDs and invalidate caches.
-                cache_tuples = {(r[1],) for r in relations_to_insert}
                 self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
-                    txn, self.get_relations_for_event, cache_tuples  # type: ignore[attr-defined]
+                    txn,
+                    self.get_relations_for_event,  # type: ignore[attr-defined]
+                    {
+                        (
+                            r[0],  # room_id
+                            r[2],  # parent_id
+                        )
+                        for r in relations_to_insert
+                    },
                 )
                 self._invalidate_cache_and_stream_bulk(  # type: ignore[attr-defined]
-                    txn, self.get_thread_summary, cache_tuples  # type: ignore[attr-defined]
+                    txn,
+                    self.get_thread_summary,  # type: ignore[attr-defined]
+                    {(r[2],) for r in relations_to_insert},  # parent_id, the thread root
                 )
 
             if results:
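The positional `r[0]`/`r[1]`/`r[2]` indexing is easy to misread now that `room_id` is prepended to each tuple. A named tuple, sketched here as a hypothetical `Relation` type that is not in the patch, would make the derived cache keys self-describing:

    from typing import List, NamedTuple

    class Relation(NamedTuple):
        room_id: str
        event_id: str
        parent_id: str
        rel_type: str

    relations_to_insert: List[Relation] = []
    # Cache keys then read directly off field names:
    #   get_relations_for_event -> (r.room_id, r.parent_id)
    #   get_thread_summary      -> (r.parent_id,)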
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d93c26336b..d074306bf0 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,12 +75,10 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
@@ -195,51 +193,28 @@ class EventsWorkerStore(SQLBaseStore):
 
         self._stream_id_gen: AbstractStreamIdGenerator
         self._backfill_id_gen: AbstractStreamIdGenerator
-        if isinstance(database.engine, PostgresEngine):
-            # If we're using Postgres than we can use `MultiWriterIdGenerator`
-            # regardless of whether this process writes to the streams or not.
-            self._stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="events",
-                instance_name=hs.get_instance_name(),
-                tables=[("events", "instance_name", "stream_ordering")],
-                sequence_name="events_stream_seq",
-                writers=hs.config.worker.writers.events,
-            )
-            self._backfill_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="backfill",
-                instance_name=hs.get_instance_name(),
-                tables=[("events", "instance_name", "stream_ordering")],
-                sequence_name="events_backfill_stream_seq",
-                positive=False,
-                writers=hs.config.worker.writers.events,
-            )
-        else:
-            # Multiple writers are not supported for SQLite.
-            #
-            # We shouldn't be running in worker mode with SQLite, but its useful
-            # to support it for unit tests.
-            self._stream_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "events",
-                "stream_ordering",
-                is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
-            )
-            self._backfill_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "events",
-                "stream_ordering",
-                step=-1,
-                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-                is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
-            )
+
+        self._stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="events",
+            instance_name=hs.get_instance_name(),
+            tables=[("events", "instance_name", "stream_ordering")],
+            sequence_name="events_stream_seq",
+            writers=hs.config.worker.writers.events,
+        )
+        self._backfill_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="backfill",
+            instance_name=hs.get_instance_name(),
+            tables=[("events", "instance_name", "stream_ordering")],
+            sequence_name="events_backfill_stream_seq",
+            positive=False,
+            writers=hs.config.worker.writers.events,
+        )
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
@@ -309,27 +284,17 @@ class EventsWorkerStore(SQLBaseStore):
 
         self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="un_partial_stated_event_stream",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    ("un_partial_stated_event_stream", "instance_name", "stream_id")
-                ],
-                sequence_name="un_partial_stated_event_stream_sequence",
-                # TODO(faster_joins, multiple writers) Support multiple writers.
-                writers=["master"],
-            )
-        else:
-            self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "un_partial_stated_event_stream",
-                "stream_id",
-            )
+        self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="un_partial_stated_event_stream",
+            instance_name=hs.get_instance_name(),
+            tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
+            sequence_name="un_partial_stated_event_stream_sequence",
+            # TODO(faster_joins, multiple writers) Support multiple writers.
+            writers=["master"],
+        )
 
     def get_un_partial_stated_events_token(self, instance_name: str) -> int:
         return (
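`positive=False` makes the backfill generator hand out decreasing negative `stream_ordering` values, so backfilled events always sort before live ones regardless of which writer persisted them. A sketch with illustrative values:

    # Live stream:      1,  2,  3, ...  (events_stream_seq)
    # Backfill stream: -1, -2, -3, ...  (events_backfill_stream_seq)
    async with store._backfill_id_gen.get_next() as stream_ordering:
        assert stream_ordering < 0  # backfill IDs grow downwards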
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 567c2d30bd..923e764491 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -40,13 +40,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -91,21 +89,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             self._instance_name in hs.config.worker.writers.presence
         )
 
-        if isinstance(database.engine, PostgresEngine):
-            self._presence_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="presence_stream",
-                instance_name=self._instance_name,
-                tables=[("presence_stream", "instance_name", "stream_id")],
-                sequence_name="presence_stream_sequence",
-                writers=hs.config.worker.writers.presence,
-            )
-        else:
-            self._presence_id_gen = StreamIdGenerator(
-                db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
-            )
+        self._presence_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="presence_stream",
+            instance_name=self._instance_name,
+            tables=[("presence_stream", "instance_name", "stream_id")],
+            sequence_name="presence_stream_sequence",
+            writers=hs.config.worker.writers.presence,
+        )
 
         self.hs = hs
         self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 13387a3839..8432560a89 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -44,12 +44,10 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import (
     JsonDict,
@@ -80,35 +78,20 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # class below that is used on the main process.
         self._receipts_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._can_write_to_receipts = (
-                self._instance_name in hs.config.worker.writers.receipts
-            )
+        self._can_write_to_receipts = (
+            self._instance_name in hs.config.worker.writers.receipts
+        )
 
-            self._receipts_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="receipts",
-                instance_name=self._instance_name,
-                tables=[("receipts_linearized", "instance_name", "stream_id")],
-                sequence_name="receipts_sequence",
-                writers=hs.config.worker.writers.receipts,
-            )
-        else:
-            self._can_write_to_receipts = True
-
-            # Multiple writers are not supported for SQLite.
-            #
-            # We shouldn't be running in worker mode with SQLite, but its useful
-            # to support it for unit tests.
-            self._receipts_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "receipts_linearized",
-                "stream_id",
-                is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
-            )
+        self._receipts_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="receipts",
+            instance_name=self._instance_name,
+            tables=[("receipts_linearized", "instance_name", "stream_id")],
+            sequence_name="receipts_sequence",
+            writers=hs.config.worker.writers.receipts,
+        )
 
         super().__init__(database, db_conn, hs)
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 77f3641525..29a001ff92 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
     @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
+        room_id: str,
         event_id: str,
         event: EventBase,
-        room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         limit: int = 5,
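Because the cache is declared with `tree=True`, entries can be invalidated by any prefix of the positional key, and `uncached_args=("event",)` keeps the event body out of the key entirely. Moving `room_id` ahead of `event_id` therefore makes `(room_id, event_id)` a valid invalidation prefix, which is exactly the tuple the `events.py` and `events_bg_updates.py` changes above pass. A hedged sketch (the trailing key components are assumed from the signature):

    # Entries live under keys like:
    #   (room_id, event_id, relation_type, event_type, limit, ...)
    # Invalidating on the two-element prefix drops every cached variant
    # of the query for that event at once:
    store.get_relations_for_event.invalidate((room_id, event_id))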
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 8205109548..616c941687 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -58,13 +58,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
     IdGenerator,
     MultiWriterIdGenerator,
-    StreamIdGenerator,
 )
 from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
 from synapse.util import json_encoder
@@ -155,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
 
-        if isinstance(database.engine, PostgresEngine):
-            self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                notifier=hs.get_replication_notifier(),
-                stream_name="un_partial_stated_room_stream",
-                instance_name=self._instance_name,
-                tables=[
-                    ("un_partial_stated_room_stream", "instance_name", "stream_id")
-                ],
-                sequence_name="un_partial_stated_room_stream_sequence",
-                # TODO(faster_joins, multiple writers) Support multiple writers.
-                writers=["master"],
-            )
-        else:
-            self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
-                db_conn,
-                hs.get_replication_notifier(),
-                "un_partial_stated_room_stream",
-                "stream_id",
-            )
+        self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="un_partial_stated_room_stream",
+            instance_name=self._instance_name,
+            tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
+            sequence_name="un_partial_stated_room_stream_sequence",
+            # TODO(faster_joins, multiple writers) Support multiple writers.
+            writers=["master"],
+        )
 
     def process_replication_position(
         self, stream_name: str, instance_name: str, token: int
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index b9168ee074..90641d5a18 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -142,6 +142,10 @@ class PostgresEngine(
         apply stricter checks on new databases versus existing database.
         """
 
+        allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
+        if allow_unsafe_locale:
+            return
+
         collation, ctype = self.get_db_locale(txn)
 
         errors = []
@@ -155,7 +159,9 @@ class PostgresEngine(
         if errors:
             raise IncorrectDatabaseSetup(
                 "Database is incorrectly configured:\n\n%s\n\n"
-                "See docs/postgres.md for more information." % ("\n".join(errors))
+                "See docs/postgres.md for more information. You can override this check by"
+                "setting 'allow_unsafe_locale' to true in the database config.",
+                "\n".join(errors),
             )
 
     def convert_param_style(self, sql: str) -> str:
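The override is read from the homeserver's `database` section; a sketch of enabling it (the setting name comes from the code above, the surrounding keys are illustrative):

    database:
      name: psycopg2
      args:
        user: synapse_user
        database: synapse
      # Skip the COLLATE/CTYPE check; only safe if you accept the
      # index-corruption risk of non-C locales.
      allow_unsafe_locale: true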
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index fadc75cc80..0cf5851ad7 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -53,9 +53,11 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.sequence import PostgresSequenceGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 
 if TYPE_CHECKING:
     from synapse.notifier import ReplicationNotifier
@@ -432,7 +434,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # no active writes in progress.
         self._max_position_of_local_instance = self._max_seen_allocated_stream_id
 
-        self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, tables)
+
+        self._sequence_gen = build_sequence_generator(
+            db_conn=db_conn,
+            database_engine=db.engine,
+            get_first_callback=lambda _: self._persisted_upto_position,
+            sequence_name=sequence_name,
+            # We only need to set the below if we want it to call
+            # `check_consistency`, but we do that ourselves below so we can
+            # leave them blank.
+            table=None,
+            id_column=None,
+            stream_name=None,
+            positive=positive,
+        )
 
         # We check that the table and sequence haven't diverged.
         for table, _, id_column in tables:
@@ -444,9 +461,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 positive=positive,
             )
 
-        # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, tables)
-
         self._max_seen_allocated_stream_id = max(
             self._current_positions.values(), default=1
         )
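`build_sequence_generator` returns a real Postgres sequence wrapper on Postgres and a local counter on SQLite seeded from `get_first_callback`, which is why `_load_current_ids` now runs first: the callback reads `self._persisted_upto_position`. Either implementation is driven the same way inside a transaction; a minimal sketch:

    def _allocate_txn(self, txn: LoggingTransaction) -> int:
        # Postgres: nextval(sequence_name).
        # SQLite: an in-process counter seeded, on first use, from
        # get_first_callback (here self._persisted_upto_position).
        return self._sequence_gen.get_next_id_txn(txn)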
@@ -480,13 +494,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             # important if we add back a writer after a long time; we want to
             # consider that a "new" writer, rather than using the old stale
             # entry here.
-            sql = """
+            clause, args = make_in_list_sql_clause(
+                self._db.engine, "instance_name", self._writers, negative=True
+            )
+
+            sql = f"""
                 DELETE FROM stream_positions
                 WHERE
                     stream_name = ?
-                    AND instance_name != ALL(?)
+                    AND {clause}
             """
-            cur.execute(sql, (self._stream_name, self._writers))
+            cur.execute(sql, [self._stream_name] + args)
 
             sql = """
                 SELECT instance_name, stream_id FROM stream_positions
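`make_in_list_sql_clause(..., negative=True)` renders an engine-appropriate exclusion clause plus its bind parameters, replacing the Postgres-only `!= ALL(?)` array syntax. Roughly, with the exact rendered SQL being an implementation detail:

    clause, args = make_in_list_sql_clause(
        self._db.engine, "instance_name", ["writer1", "writer2"], negative=True
    )
    # Postgres: roughly "NOT (instance_name = ANY(?))" with an array argument
    # SQLite:   roughly "instance_name NOT IN (?, ?)" with flat arguments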
@@ -508,12 +526,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             # We add a GREATEST here to ensure that the result is always
             # positive. (This can be a problem for e.g. backfill streams where
             # the server has never backfilled).
+            greatest_func = (
+                "GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
+            )
             max_stream_id = 1
             for table, _, id_column in tables:
                 sql = """
-                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                    SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
                     FROM %(table)s
                 """ % {
+                    "greatest_func": greatest_func,
                     "id": id_column,
                     "table": table,
                     "agg": "MAX" if self._positive else "-MIN",
@@ -913,6 +935,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         # We upsert the value, ensuring on conflict that we always increase the
         # value (or decrease if stream goes backwards).
+        if isinstance(self._db.engine, PostgresEngine):
+            agg = "GREATEST" if self._positive else "LEAST"
+        else:
+            agg = "MAX" if self._positive else "MIN"
+
         sql = """
             INSERT INTO stream_positions (stream_name, instance_name, stream_id)
             VALUES (?, ?, ?)
@@ -920,10 +947,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             DO UPDATE SET
                 stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
         """ % {
-            "agg": "GREATEST" if self._positive else "LEAST",
+            "agg": agg,
         }
 
-        pos = (self.get_current_token_for_writer(self._instance_name),)
+        pos = self.get_current_token_for_writer(self._instance_name)
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
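The same scalar-function split applies to the upsert: `GREATEST`/`LEAST` on Postgres become two-argument `MAX`/`MIN` on SQLite. The `pos` fix matters independently: the removed line bound a one-element tuple as the third parameter, which drivers reject as a nested binding, whereas the bare integer binds cleanly. Rendered for a positive stream on SQLite, illustratively:

    # INSERT INTO stream_positions (stream_name, instance_name, stream_id)
    # VALUES (?, ?, ?)
    # ON CONFLICT (stream_name, instance_name)
    # DO UPDATE SET
    #     stream_id = MAX(stream_positions.stream_id, EXCLUDED.stream_id)
    txn.execute(sql, (self._stream_name, self._instance_name, pos))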