summary refs log tree commit diff
path: root/synapse/media
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/media')
-rw-r--r--synapse/media/_base.py479
-rw-r--r--synapse/media/filepath.py410
-rw-r--r--synapse/media/media_repository.py1038
-rw-r--r--synapse/media/media_storage.py374
-rw-r--r--synapse/media/oembed.py265
-rw-r--r--synapse/media/preview_html.py501
-rw-r--r--synapse/media/storage_provider.py181
-rw-r--r--synapse/media/thumbnailer.py221
8 files changed, 3469 insertions, 0 deletions
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
new file mode 100644
index 0000000000..ef8334ae25
--- /dev/null
+++ b/synapse/media/_base.py
@@ -0,0 +1,479 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import urllib
+from abc import ABC, abstractmethod
+from types import TracebackType
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+
+import attr
+
+from twisted.internet.interfaces import IConsumer
+from twisted.protocols.basic import FileSender
+from twisted.web.server import Request
+
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.http.server import finish_request, respond_with_json
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+# list all text content types that will have the charset default to UTF-8 when
+# none is given
+TEXT_CONTENT_TYPES = [
+    "text/css",
+    "text/csv",
+    "text/html",
+    "text/calendar",
+    "text/plain",
+    "text/javascript",
+    "application/json",
+    "application/ld+json",
+    "application/rtf",
+    "image/svg+xml",
+    "text/xml",
+]
+
+
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
+    """Parses the server name, media ID and optional file name from the request URI
+
+    Also performs some rough validation on the server name.
+
+    Args:
+        request: The `Request`.
+
+    Returns:
+        A tuple containing the parsed server name, media ID and optional file name.
+
+    Raises:
+        SynapseError(404): if parsing or validation fail for any reason
+    """
+    try:
+        # The type on postpath seems incorrect in Twisted 21.2.0.
+        postpath: List[bytes] = request.postpath  # type: ignore
+        assert postpath
+
+        # This allows users to append e.g. /test.png to the URL. Useful for
+        # clients that parse the URL to see content type.
+        server_name_bytes, media_id_bytes = postpath[:2]
+        server_name = server_name_bytes.decode("utf-8")
+        media_id = media_id_bytes.decode("utf8")
+
+        # Validate the server name, raising if invalid
+        parse_and_validate_server_name(server_name)
+
+        file_name = None
+        if len(postpath) > 2:
+            try:
+                file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
+            except UnicodeDecodeError:
+                pass
+        return server_name, media_id, file_name
+    except Exception:
+        raise SynapseError(
+            404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN
+        )
+
+
+def respond_404(request: SynapseRequest) -> None:
+    respond_with_json(
+        request,
+        404,
+        cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND),
+        send_cors=True,
+    )
+
+
+async def respond_with_file(
+    request: SynapseRequest,
+    media_type: str,
+    file_path: str,
+    file_size: Optional[int] = None,
+    upload_name: Optional[str] = None,
+) -> None:
+    logger.debug("Responding with %r", file_path)
+
+    if os.path.isfile(file_path):
+        if file_size is None:
+            stat = os.stat(file_path)
+            file_size = stat.st_size
+
+        add_file_headers(request, media_type, file_size, upload_name)
+
+        with open(file_path, "rb") as f:
+            await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+
+        finish_request(request)
+    else:
+        respond_404(request)
+
+
+def add_file_headers(
+    request: Request,
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str],
+) -> None:
+    """Adds the correct response headers in preparation for responding with the
+    media.
+
+    Args:
+        request
+        media_type: The media/content type.
+        file_size: Size in bytes of the media, if known.
+        upload_name: The name of the requested file, if any.
+    """
+
+    def _quote(x: str) -> str:
+        return urllib.parse.quote(x.encode("utf-8"))
+
+    # Default to a UTF-8 charset for text content types.
+    # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
+    if media_type.lower() in TEXT_CONTENT_TYPES:
+        content_type = media_type + "; charset=UTF-8"
+    else:
+        content_type = media_type
+
+    request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
+    if upload_name:
+        # RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
+        #
+        # `filename` is defined to be a `value`, which is defined by RFC2616
+        # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
+        # is (essentially) a single US-ASCII word, and a `quoted-string` is a
+        # US-ASCII string surrounded by double-quotes, using backslash as an
+        # escape character. Note that %-encoding is *not* permitted.
+        #
+        # `filename*` is defined to be an `ext-value`, which is defined in
+        # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
+        # where `value-chars` is essentially a %-encoded string in the given charset.
+        #
+        # [1]: https://tools.ietf.org/html/rfc6266#section-4.1
+        # [2]: https://tools.ietf.org/html/rfc2616#section-3.6
+        # [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1
+
+        # We avoid the quoted-string version of `filename`, because (a) synapse didn't
+        # correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
+        # may as well just do the filename* version.
+        if _can_encode_filename_as_token(upload_name):
+            disposition = "inline; filename=%s" % (upload_name,)
+        else:
+            disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
+
+        request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
+
+    # cache for at least a day.
+    # XXX: we might want to turn this off for data we don't want to
+    # recommend caching as it's sensitive or private - or at least
+    # select private. don't bother setting Expires as all our
+    # clients are smart enough to be happy with Cache-Control
+    request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+    # Tell web crawlers to not index, archive, or follow links in media. This
+    # should help to prevent things in the media repo from showing up in web
+    # search results.
+    request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+
+
+# separators as defined in RFC2616. SP and HT are handled separately.
+# see _can_encode_filename_as_token.
+_FILENAME_SEPARATOR_CHARS = {
+    "(",
+    ")",
+    "<",
+    ">",
+    "@",
+    ",",
+    ";",
+    ":",
+    "\\",
+    '"',
+    "/",
+    "[",
+    "]",
+    "?",
+    "=",
+    "{",
+    "}",
+}
+
+
+def _can_encode_filename_as_token(x: str) -> bool:
+    for c in x:
+        # from RFC2616:
+        #
+        #        token          = 1*<any CHAR except CTLs or separators>
+        #
+        #        separators     = "(" | ")" | "<" | ">" | "@"
+        #                       | "," | ";" | ":" | "\" | <">
+        #                       | "/" | "[" | "]" | "?" | "="
+        #                       | "{" | "}" | SP | HT
+        #
+        #        CHAR           = <any US-ASCII character (octets 0 - 127)>
+        #
+        #        CTL            = <any US-ASCII control character
+        #                         (octets 0 - 31) and DEL (127)>
+        #
+        if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS:
+            return False
+    return True
+
+
+async def respond_with_responder(
+    request: SynapseRequest,
+    responder: "Optional[Responder]",
+    media_type: str,
+    file_size: Optional[int],
+    upload_name: Optional[str] = None,
+) -> None:
+    """Responds to the request with given responder. If responder is None then
+    returns 404.
+
+    Args:
+        request
+        responder
+        media_type: The media/content type.
+        file_size: Size in bytes of the media. If not known it should be None
+        upload_name: The name of the requested file, if any.
+    """
+    if not responder:
+        respond_404(request)
+        return
+
+    # If we have a responder we *must* use it as a context manager.
+    with responder:
+        if request._disconnected:
+            logger.warning(
+                "Not sending response to request %s, already disconnected.", request
+            )
+            return
+
+        logger.debug("Responding to media request with responder %s", responder)
+        add_file_headers(request, media_type, file_size, upload_name)
+        try:
+            await responder.write_to_consumer(request)
+        except Exception as e:
+            # The majority of the time this will be due to the client having gone
+            # away. Unfortunately, Twisted simply throws a generic exception at us
+            # in that case.
+            logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+            # Unregister the producer, if it has one, so Twisted doesn't complain
+            if request.producer:
+                request.unregisterProducer()
+
+    finish_request(request)
+
+
+class Responder(ABC):
+    """Represents a response that can be streamed to the requester.
+
+    Responder is a context manager which *must* be used, so that any resources
+    held can be cleaned up.
+    """
+
+    @abstractmethod
+    def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
+        """Stream response into consumer
+
+        Args:
+            consumer: The consumer to stream into.
+
+        Returns:
+            Resolves once the response has finished being written
+        """
+        raise NotImplementedError()
+
+    def __enter__(self) -> None:  # noqa: B027
+        pass
+
+    def __exit__(  # noqa: B027
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        pass
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThumbnailInfo:
+    """Details about a generated thumbnail."""
+
+    width: int
+    height: int
+    method: str
+    # Content type of thumbnail, e.g. image/png
+    type: str
+    # The size of the media file, in bytes.
+    length: Optional[int] = None
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FileInfo:
+    """Details about a requested/uploaded file."""
+
+    # The server name where the media originated from, or None if local.
+    server_name: Optional[str]
+    # The local ID of the file. For local files this is the same as the media_id
+    file_id: str
+    # If the file is for the url preview cache
+    url_cache: bool = False
+    # Whether the file is a thumbnail or not.
+    thumbnail: Optional[ThumbnailInfo] = None
+
+    # The below properties exist to maintain compatibility with third-party modules.
+    @property
+    def thumbnail_width(self) -> Optional[int]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.width
+
+    @property
+    def thumbnail_height(self) -> Optional[int]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.height
+
+    @property
+    def thumbnail_method(self) -> Optional[str]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.method
+
+    @property
+    def thumbnail_type(self) -> Optional[str]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.type
+
+    @property
+    def thumbnail_length(self) -> Optional[int]:
+        if not self.thumbnail:
+            return None
+        return self.thumbnail.length
+
+
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
+    """
+    Get the filename of the downloaded file by inspecting the
+    Content-Disposition HTTP header.
+
+    Args:
+        headers: The HTTP request headers.
+
+    Returns:
+        The filename, or None.
+    """
+    content_disposition = headers.get(b"Content-Disposition", [b""])
+
+    # No header, bail out.
+    if not content_disposition[0]:
+        return None
+
+    _, params = _parse_header(content_disposition[0])
+
+    upload_name = None
+
+    # First check if there is a valid UTF-8 filename
+    upload_name_utf8 = params.get(b"filename*", None)
+    if upload_name_utf8:
+        if upload_name_utf8.lower().startswith(b"utf-8''"):
+            upload_name_utf8 = upload_name_utf8[7:]
+            # We have a filename*= section. This MUST be ASCII, and any UTF-8
+            # bytes are %-quoted.
+            try:
+                # Once it is decoded, we can then unquote the %-encoded
+                # parts strictly into a unicode string.
+                upload_name = urllib.parse.unquote(
+                    upload_name_utf8.decode("ascii"), errors="strict"
+                )
+            except UnicodeDecodeError:
+                # Incorrect UTF-8.
+                pass
+
+    # If there isn't check for an ascii name.
+    if not upload_name:
+        upload_name_ascii = params.get(b"filename", None)
+        if upload_name_ascii and is_ascii(upload_name_ascii):
+            upload_name = upload_name_ascii.decode("ascii")
+
+    # This may be None here, indicating we did not find a matching name.
+    return upload_name
+
+
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
+    """Parse a Content-type like header.
+
+    Cargo-culted from `cgi`, but works on bytes rather than strings.
+
+    Args:
+        line: header to be parsed
+
+    Returns:
+        The main content-type, followed by the parameter dictionary
+    """
+    parts = _parseparam(b";" + line)
+    key = next(parts)
+    pdict = {}
+    for p in parts:
+        i = p.find(b"=")
+        if i >= 0:
+            name = p[:i].strip().lower()
+            value = p[i + 1 :].strip()
+
+            # strip double-quotes
+            if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
+                value = value[1:-1]
+                value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
+            pdict[name] = value
+
+    return key, pdict
+
+
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
+    """Generator which splits the input on ;, respecting double-quoted sequences
+
+    Cargo-culted from `cgi`, but works on bytes rather than strings.
+
+    Args:
+        s: header to be parsed
+
+    Returns:
+        The split input
+    """
+    while s[:1] == b";":
+        s = s[1:]
+
+        # look for the next ;
+        end = s.find(b";")
+
+        # if there is an odd number of " marks between here and the next ;, skip to the
+        # next ; instead
+        while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
+            end = s.find(b";", end + 1)
+
+        if end < 0:
+            end = len(s)
+        f = s[:end]
+        yield f.strip()
+        s = s[end:]
diff --git a/synapse/media/filepath.py b/synapse/media/filepath.py
new file mode 100644
index 0000000000..1f6441c412
--- /dev/null
+++ b/synapse/media/filepath.py
@@ -0,0 +1,410 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import os
+import re
+import string
+from typing import Any, Callable, List, TypeVar, Union, cast
+
+NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
+
+
+F = TypeVar("F", bound=Callable[..., str])
+
+
+def _wrap_in_base_path(func: F) -> F:
+    """Takes a function that returns a relative path and turns it into an
+    absolute path based on the location of the primary media store
+    """
+
+    @functools.wraps(func)
+    def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
+        path = func(self, *args, **kwargs)
+        return os.path.join(self.base_path, path)
+
+    return cast(F, _wrapped)
+
+
+GetPathMethod = TypeVar(
+    "GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
+)
+
+
+def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]:
+    """Wraps a path-returning method to check that the returned path(s) do not escape
+    the media store directory.
+
+    The path-returning method may return either a single path, or a list of paths.
+
+    The check is not expected to ever fail, unless `func` is missing a call to
+    `_validate_path_component`, or `_validate_path_component` is buggy.
+
+    Args:
+        relative: A boolean indicating whether the wrapped method returns paths relative
+            to the media store directory.
+
+    Returns:
+        A method which will wrap a path-returning method, adding a check to ensure that
+        the returned path(s) lie within the media store directory. The check will raise
+        a `ValueError` if it fails.
+    """
+
+    def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod:
+        @functools.wraps(func)
+        def _wrapped(
+            self: "MediaFilePaths", *args: Any, **kwargs: Any
+        ) -> Union[str, List[str]]:
+            path_or_paths = func(self, *args, **kwargs)
+
+            if isinstance(path_or_paths, list):
+                paths_to_check = path_or_paths
+            else:
+                paths_to_check = [path_or_paths]
+
+            for path in paths_to_check:
+                # Construct the path that will ultimately be used.
+                # We cannot guess whether `path` is relative to the media store
+                # directory, since the media store directory may itself be a relative
+                # path.
+                if relative:
+                    path = os.path.join(self.base_path, path)
+                normalized_path = os.path.normpath(path)
+
+                # Now that `normpath` has eliminated `../`s and `./`s from the path,
+                # `os.path.commonpath` can be used to check whether it lies within the
+                # media store directory.
+                if (
+                    os.path.commonpath([normalized_path, self.normalized_base_path])
+                    != self.normalized_base_path
+                ):
+                    # The path resolves to outside the media store directory,
+                    # or `self.base_path` is `.`, which is an unlikely configuration.
+                    raise ValueError(f"Invalid media store path: {path!r}")
+
+                # Note that `os.path.normpath`/`abspath` has a subtle caveat:
+                # `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a
+                # different path if `a/b/c` is a symlink. That is, the check above is
+                # not perfect and may allow a certain restricted subset of untrustworthy
+                # paths through. Since the check above is secondary to the main
+                # `_validate_path_component` checks, it's less important for it to be
+                # perfect.
+                #
+                # As an alternative, `os.path.realpath` will resolve symlinks, but
+                # proves problematic if there are symlinks inside the media store.
+                # eg. if `url_store/` is symlinked to elsewhere, its canonical path
+                # won't match that of the main media store directory.
+
+            return path_or_paths
+
+        return cast(GetPathMethod, _wrapped)
+
+    return _wrap_with_jail_check_inner
+
+
+ALLOWED_CHARACTERS = set(
+    string.ascii_letters
+    + string.digits
+    + "_-"
+    + ".[]:"  # Domain names, IPv6 addresses and ports in server names
+)
+FORBIDDEN_NAMES = {
+    "",
+    os.path.curdir,  # "." for the current platform
+    os.path.pardir,  # ".." for the current platform
+}
+
+
+def _validate_path_component(name: str) -> str:
+    """Checks that the given string can be safely used as a path component
+
+    Args:
+        name: The path component to check.
+
+    Returns:
+        The path component if valid.
+
+    Raises:
+        ValueError: If `name` cannot be safely used as a path component.
+    """
+    if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
+        raise ValueError(f"Invalid path component: {name!r}")
+
+    return name
+
+
+class MediaFilePaths:
+    """Describes where files are stored on disk.
+
+    Most of the functions have a `*_rel` variant which returns a file path that
+    is relative to the base media store path. This is mainly used when we want
+    to write to the backup media store (when one is configured)
+    """
+
+    def __init__(self, primary_base_path: str):
+        self.base_path = primary_base_path
+        self.normalized_base_path = os.path.normpath(self.base_path)
+
+        # Refuse to initialize if paths cannot be validated correctly for the current
+        # platform.
+        assert os.path.sep not in ALLOWED_CHARACTERS
+        assert os.path.altsep not in ALLOWED_CHARACTERS
+        # On Windows, paths have all sorts of weirdness which `_validate_path_component`
+        # does not consider. In any case, the remote media store can't work correctly
+        # for certain homeservers there, since ":"s aren't allowed in paths.
+        assert os.name == "posix"
+
+    @_wrap_with_jail_check(relative=True)
+    def local_media_filepath_rel(self, media_id: str) -> str:
+        return os.path.join(
+            "local_content",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+        )
+
+    local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
+
+    @_wrap_with_jail_check(relative=True)
+    def local_media_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
+        return os.path.join(
+            "local_thumbnails",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+            _validate_path_component(file_name),
+        )
+
+    local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+
+    @_wrap_with_jail_check(relative=False)
+    def local_media_thumbnail_dir(self, media_id: str) -> str:
+        """
+        Retrieve the local store path of thumbnails of a given media_id
+
+        Args:
+            media_id: The media ID to query.
+        Returns:
+            Path of local_thumbnails from media_id
+        """
+        return os.path.join(
+            self.base_path,
+            "local_thumbnails",
+            _validate_path_component(media_id[0:2]),
+            _validate_path_component(media_id[2:4]),
+            _validate_path_component(media_id[4:]),
+        )
+
+    @_wrap_with_jail_check(relative=True)
+    def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
+        return os.path.join(
+            "remote_content",
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+        )
+
+    remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
+
+    @_wrap_with_jail_check(relative=True)
+    def remote_media_thumbnail_rel(
+        self,
+        server_name: str,
+        file_id: str,
+        width: int,
+        height: int,
+        content_type: str,
+        method: str,
+    ) -> str:
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
+        return os.path.join(
+            "remote_thumbnail",
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+            _validate_path_component(file_name),
+        )
+
+    remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+
+    # Legacy path that was used to store thumbnails previously.
+    # Should be removed after some time, when most of the thumbnails are stored
+    # using the new path.
+    @_wrap_with_jail_check(relative=True)
+    def remote_media_thumbnail_rel_legacy(
+        self, server_name: str, file_id: str, width: int, height: int, content_type: str
+    ) -> str:
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+        return os.path.join(
+            "remote_thumbnail",
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+            _validate_path_component(file_name),
+        )
+
+    @_wrap_with_jail_check(relative=False)
+    def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
+        return os.path.join(
+            self.base_path,
+            "remote_thumbnail",
+            _validate_path_component(server_name),
+            _validate_path_component(file_id[0:2]),
+            _validate_path_component(file_id[2:4]),
+            _validate_path_component(file_id[4:]),
+        )
+
+    @_wrap_with_jail_check(relative=True)
+    def url_cache_filepath_rel(self, media_id: str) -> str:
+        if NEW_FORMAT_ID_RE.match(media_id):
+            # Media id is of the form <DATE><RANDOM_STRING>
+            # E.g.: 2017-09-28-fsdRDt24DS234dsf
+            return os.path.join(
+                "url_cache",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+            )
+        else:
+            return os.path.join(
+                "url_cache",
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+            )
+
+    url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
+
+    @_wrap_with_jail_check(relative=False)
+    def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
+        "The dirs to try and remove if we delete the media_id file"
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return [
+                os.path.join(
+                    self.base_path, "url_cache", _validate_path_component(media_id[:10])
+                )
+            ]
+        else:
+            return [
+                os.path.join(
+                    self.base_path,
+                    "url_cache",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                ),
+                os.path.join(
+                    self.base_path, "url_cache", _validate_path_component(media_id[0:2])
+                ),
+            ]
+
+    @_wrap_with_jail_check(relative=True)
+    def url_cache_thumbnail_rel(
+        self, media_id: str, width: int, height: int, content_type: str, method: str
+    ) -> str:
+        # Media id is of the form <DATE><RANDOM_STRING>
+        # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
+
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return os.path.join(
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+                _validate_path_component(file_name),
+            )
+        else:
+            return os.path.join(
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+                _validate_path_component(file_name),
+            )
+
+    url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
+
+    @_wrap_with_jail_check(relative=True)
+    def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
+        # Media id is of the form <DATE><RANDOM_STRING>
+        # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return os.path.join(
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[:10]),
+                _validate_path_component(media_id[11:]),
+            )
+        else:
+            return os.path.join(
+                "url_cache_thumbnails",
+                _validate_path_component(media_id[0:2]),
+                _validate_path_component(media_id[2:4]),
+                _validate_path_component(media_id[4:]),
+            )
+
+    url_cache_thumbnail_directory = _wrap_in_base_path(
+        url_cache_thumbnail_directory_rel
+    )
+
+    @_wrap_with_jail_check(relative=False)
+    def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
+        "The dirs to try and remove if we delete the media_id thumbnails"
+        # Media id is of the form <DATE><RANDOM_STRING>
+        # E.g.: 2017-09-28-fsdRDt24DS234dsf
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return [
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[:10]),
+                    _validate_path_component(media_id[11:]),
+                ),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[:10]),
+                ),
+            ]
+        else:
+            return [
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                    _validate_path_component(media_id[4:]),
+                ),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
+                    _validate_path_component(media_id[2:4]),
+                ),
+                os.path.join(
+                    self.base_path,
+                    "url_cache_thumbnails",
+                    _validate_path_component(media_id[0:2]),
+                ),
+            ]
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
new file mode 100644
index 0000000000..b81e3c2b0c
--- /dev/null
+++ b/synapse/media/media_repository.py
@@ -0,0 +1,1038 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import errno
+import logging
+import os
+import shutil
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+
+from matrix_common.types.mxc_uri import MXCUri
+
+import twisted.internet.error
+import twisted.web.http
+from twisted.internet.defer import Deferred
+
+from synapse.api.errors import (
+    FederationDeniedError,
+    HttpResponseException,
+    NotFoundError,
+    RequestSendFailed,
+    SynapseError,
+)
+from synapse.config.repository import ThumbnailRequirement
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import defer_to_thread
+from synapse.media._base import (
+    FileInfo,
+    Responder,
+    ThumbnailInfo,
+    get_filename_from_headers,
+    respond_404,
+    respond_with_responder,
+)
+from synapse.media.filepath import MediaFilePaths
+from synapse.media.media_storage import MediaStorage
+from synapse.media.storage_provider import StorageProviderWrapper
+from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
+from synapse.util.async_helpers import Linearizer
+from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+# How often to run the background job to update the "recently accessed"
+# attribute of local and remote media.
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000  # 1 minute
+# How often to run the background job to check for local and remote media
+# that should be purged according to the configured media retention settings.
+MEDIA_RETENTION_CHECK_PERIOD_MS = 60 * 60 * 1000  # 1 hour
+
+
+class MediaRepository:
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.client = hs.get_federation_http_client()
+        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
+        self.store = hs.get_datastores().main
+        self.max_upload_size = hs.config.media.max_upload_size
+        self.max_image_pixels = hs.config.media.max_image_pixels
+
+        Thumbnailer.set_limits(self.max_image_pixels)
+
+        self.primary_base_path: str = hs.config.media.media_store_path
+        self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
+
+        self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+        self.thumbnail_requirements = hs.config.media.thumbnail_requirements
+
+        self.remote_media_linearizer = Linearizer(name="media_remote")
+
+        self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
+        self.recently_accessed_locals: Set[str] = set()
+
+        self.federation_domain_whitelist = (
+            hs.config.federation.federation_domain_whitelist
+        )
+
+        # List of StorageProviders where we should search for media and
+        # potentially upload to.
+        storage_providers = []
+
+        for (
+            clz,
+            provider_config,
+            wrapper_config,
+        ) in hs.config.media.media_storage_providers:
+            backend = clz(hs, provider_config)
+            provider = StorageProviderWrapper(
+                backend,
+                store_local=wrapper_config.store_local,
+                store_remote=wrapper_config.store_remote,
+                store_synchronous=wrapper_config.store_synchronous,
+            )
+            storage_providers.append(provider)
+
+        self.media_storage = MediaStorage(
+            self.hs, self.primary_base_path, self.filepaths, storage_providers
+        )
+
+        self.clock.looping_call(
+            self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
+        )
+
+        # Media retention configuration options
+        self._media_retention_local_media_lifetime_ms = (
+            hs.config.media.media_retention_local_media_lifetime_ms
+        )
+        self._media_retention_remote_media_lifetime_ms = (
+            hs.config.media.media_retention_remote_media_lifetime_ms
+        )
+
+        # Check whether local or remote media retention is configured
+        if (
+            hs.config.media.media_retention_local_media_lifetime_ms is not None
+            or hs.config.media.media_retention_remote_media_lifetime_ms is not None
+        ):
+            # Run the background job to apply media retention rules routinely,
+            # with the duration between runs dictated by the homeserver config.
+            self.clock.looping_call(
+                self._start_apply_media_retention_rules,
+                MEDIA_RETENTION_CHECK_PERIOD_MS,
+            )
+
+    def _start_update_recently_accessed(self) -> Deferred:
+        return run_as_background_process(
+            "update_recently_accessed_media", self._update_recently_accessed
+        )
+
+    def _start_apply_media_retention_rules(self) -> Deferred:
+        return run_as_background_process(
+            "apply_media_retention_rules", self._apply_media_retention_rules
+        )
+
+    async def _update_recently_accessed(self) -> None:
+        remote_media = self.recently_accessed_remotes
+        self.recently_accessed_remotes = set()
+
+        local_media = self.recently_accessed_locals
+        self.recently_accessed_locals = set()
+
+        await self.store.update_cached_last_access_time(
+            local_media, remote_media, self.clock.time_msec()
+        )
+
+    def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
+        """Mark the given media as recently accessed.
+
+        Args:
+            server_name: Origin server of media, or None if local
+            media_id: The media ID of the content
+        """
+        if server_name:
+            self.recently_accessed_remotes.add((server_name, media_id))
+        else:
+            self.recently_accessed_locals.add(media_id)
+
+    async def create_content(
+        self,
+        media_type: str,
+        upload_name: Optional[str],
+        content: IO,
+        content_length: int,
+        auth_user: UserID,
+    ) -> MXCUri:
+        """Store uploaded content for a local user and return the mxc URL
+
+        Args:
+            media_type: The content type of the file.
+            upload_name: The name of the file, if provided.
+            content: A file like object that is the content to store
+            content_length: The length of the content
+            auth_user: The user_id of the uploader
+
+        Returns:
+            The mxc url of the stored content
+        """
+
+        media_id = random_string(24)
+
+        file_info = FileInfo(server_name=None, file_id=media_id)
+
+        fname = await self.media_storage.store_file(content, file_info)
+
+        logger.info("Stored local media in file %r", fname)
+
+        await self.store.store_local_media(
+            media_id=media_id,
+            media_type=media_type,
+            time_now_ms=self.clock.time_msec(),
+            upload_name=upload_name,
+            media_length=content_length,
+            user_id=auth_user,
+        )
+
+        await self._generate_thumbnails(None, media_id, media_id, media_type)
+
+        return MXCUri(self.server_name, media_id)
+
+    async def get_local_media(
+        self, request: SynapseRequest, media_id: str, name: Optional[str]
+    ) -> None:
+        """Responds to requests for local media, if exists, or returns 404.
+
+        Args:
+            request: The incoming request.
+            media_id: The media ID of the content. (This is the same as
+                the file_id for local content.)
+            name: Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Resolves once a response has successfully been written to request
+        """
+        media_info = await self.store.get_local_media(media_id)
+        if not media_info or media_info["quarantined_by"]:
+            respond_404(request)
+            return
+
+        self.mark_recently_accessed(None, media_id)
+
+        media_type = media_info["media_type"]
+        if not media_type:
+            media_type = "application/octet-stream"
+        media_length = media_info["media_length"]
+        upload_name = name if name else media_info["upload_name"]
+        url_cache = media_info["url_cache"]
+
+        file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
+
+        responder = await self.media_storage.fetch_media(file_info)
+        await respond_with_responder(
+            request, responder, media_type, media_length, upload_name
+        )
+
+    async def get_remote_media(
+        self,
+        request: SynapseRequest,
+        server_name: str,
+        media_id: str,
+        name: Optional[str],
+    ) -> None:
+        """Respond to requests for remote media.
+
+        Args:
+            request: The incoming request.
+            server_name: Remote server_name where the media originated.
+            media_id: The media ID of the content (as defined by the remote server).
+            name: Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Resolves once a response has successfully been written to request
+        """
+        if (
+            self.federation_domain_whitelist is not None
+            and server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        self.mark_recently_accessed(server_name, media_id)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
+        key = (server_name, media_id)
+        async with self.remote_media_linearizer.queue(key):
+            responder, media_info = await self._get_remote_media_impl(
+                server_name, media_id
+            )
+
+        # We deliberately stream the file outside the lock
+        if responder:
+            media_type = media_info["media_type"]
+            media_length = media_info["media_length"]
+            upload_name = name if name else media_info["upload_name"]
+            await respond_with_responder(
+                request, responder, media_type, media_length, upload_name
+            )
+        else:
+            respond_404(request)
+
+    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+        """Gets the media info associated with the remote file, downloading
+        if necessary.
+
+        Args:
+            server_name: Remote server_name where the media originated.
+            media_id: The media ID of the content (as defined by the remote server).
+
+        Returns:
+            The media info of the file
+        """
+        if (
+            self.federation_domain_whitelist is not None
+            and server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
+        key = (server_name, media_id)
+        async with self.remote_media_linearizer.queue(key):
+            responder, media_info = await self._get_remote_media_impl(
+                server_name, media_id
+            )
+
+        # Ensure we actually use the responder so that it releases resources
+        if responder:
+            with responder:
+                pass
+
+        return media_info
+
+    async def _get_remote_media_impl(
+        self, server_name: str, media_id: str
+    ) -> Tuple[Optional[Responder], dict]:
+        """Looks for media in local cache, if not there then attempt to
+        download from remote server.
+
+        Args:
+            server_name: Remote server_name where the media originated.
+            media_id: The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            A tuple of responder and the media info of the file.
+        """
+        media_info = await self.store.get_cached_remote_media(server_name, media_id)
+
+        # file_id is the ID we use to track the file locally. If we've already
+        # seen the file then reuse the existing ID, otherwise generate a new
+        # one.
+
+        # If we have an entry in the DB, try and look for it
+        if media_info:
+            file_id = media_info["filesystem_id"]
+            file_info = FileInfo(server_name, file_id)
+
+            if media_info["quarantined_by"]:
+                logger.info("Media is quarantined")
+                raise NotFoundError()
+
+            if not media_info["media_type"]:
+                media_info["media_type"] = "application/octet-stream"
+
+            responder = await self.media_storage.fetch_media(file_info)
+            if responder:
+                return responder, media_info
+
+        # Failed to find the file anywhere, lets download it.
+
+        try:
+            media_info = await self._download_remote_file(
+                server_name,
+                media_id,
+            )
+        except SynapseError:
+            raise
+        except Exception as e:
+            # An exception may be because we downloaded media in another
+            # process, so let's check if we magically have the media.
+            media_info = await self.store.get_cached_remote_media(server_name, media_id)
+            if not media_info:
+                raise e
+
+        file_id = media_info["filesystem_id"]
+        if not media_info["media_type"]:
+            media_info["media_type"] = "application/octet-stream"
+        file_info = FileInfo(server_name, file_id)
+
+        # We generate thumbnails even if another process downloaded the media
+        # as a) it's conceivable that the other download request dies before it
+        # generates thumbnails, but mainly b) we want to be sure the thumbnails
+        # have finished being generated before responding to the client,
+        # otherwise they'll request thumbnails and get a 404 if they're not
+        # ready yet.
+        await self._generate_thumbnails(
+            server_name, media_id, file_id, media_info["media_type"]
+        )
+
+        responder = await self.media_storage.fetch_media(file_info)
+        return responder, media_info
+
+    async def _download_remote_file(
+        self,
+        server_name: str,
+        media_id: str,
+    ) -> dict:
+        """Attempt to download the remote file from the given server name,
+        using the given file_id as the local id.
+
+        Args:
+            server_name: Originating server
+            media_id: The media ID of the content (as defined by the
+                remote server). This is different than the file_id, which is
+                locally generated.
+            file_id: Local file ID
+
+        Returns:
+            The media info of the file.
+        """
+
+        file_id = random_string(24)
+
+        file_info = FileInfo(server_name=server_name, file_id=file_id)
+
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            request_path = "/".join(
+                ("/_matrix/media/r0/download", server_name, media_id)
+            )
+            try:
+                length, headers = await self.client.get_file(
+                    server_name,
+                    request_path,
+                    output_stream=f,
+                    max_size=self.max_upload_size,
+                    args={
+                        # tell the remote server to 404 if it doesn't
+                        # recognise the server_name, to make sure we don't
+                        # end up with a routing loop.
+                        "allow_remote": "false"
+                    },
+                )
+            except RequestSendFailed as e:
+                logger.warning(
+                    "Request failed fetching remote media %s/%s: %r",
+                    server_name,
+                    media_id,
+                    e,
+                )
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except HttpResponseException as e:
+                logger.warning(
+                    "HTTP error fetching remote media %s/%s: %s",
+                    server_name,
+                    media_id,
+                    e.response,
+                )
+                if e.code == twisted.web.http.NOT_FOUND:
+                    raise e.to_synapse_error()
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except SynapseError:
+                logger.warning(
+                    "Failed to fetch remote media %s/%s", server_name, media_id
+                )
+                raise
+            except NotRetryingDestination:
+                logger.warning("Not retrying destination %r", server_name)
+                raise SynapseError(502, "Failed to fetch remote media")
+            except Exception:
+                logger.exception(
+                    "Failed to fetch remote media %s/%s", server_name, media_id
+                )
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            await finish()
+
+            if b"Content-Type" in headers:
+                media_type = headers[b"Content-Type"][0].decode("ascii")
+            else:
+                media_type = "application/octet-stream"
+            upload_name = get_filename_from_headers(headers)
+            time_now_ms = self.clock.time_msec()
+
+            # Multiple remote media download requests can race (when using
+            # multiple media repos), so this may throw a violation constraint
+            # exception. If it does we'll delete the newly downloaded file from
+            # disk (as we're in the ctx manager).
+            #
+            # However: we've already called `finish()` so we may have also
+            # written to the storage providers. This is preferable to the
+            # alternative where we call `finish()` *after* this, where we could
+            # end up having an entry in the DB but fail to write the files to
+            # the storage providers.
+            await self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
+
+        logger.info("Stored remote media in file %r", fname)
+
+        media_info = {
+            "media_type": media_type,
+            "media_length": length,
+            "upload_name": upload_name,
+            "created_ts": time_now_ms,
+            "filesystem_id": file_id,
+        }
+
+        return media_info
+
+    def _get_thumbnail_requirements(
+        self, media_type: str
+    ) -> Tuple[ThumbnailRequirement, ...]:
+        scpos = media_type.find(";")
+        if scpos > 0:
+            media_type = media_type[:scpos]
+        return self.thumbnail_requirements.get(media_type, ())
+
+    def _generate_thumbnail(
+        self,
+        thumbnailer: Thumbnailer,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[BytesIO]:
+        m_width = thumbnailer.width
+        m_height = thumbnailer.height
+
+        if m_width * m_height >= self.max_image_pixels:
+            logger.info(
+                "Image too large to thumbnail %r x %r > %r",
+                m_width,
+                m_height,
+                self.max_image_pixels,
+            )
+            return None
+
+        if thumbnailer.transpose_method is not None:
+            m_width, m_height = thumbnailer.transpose()
+
+        if t_method == "crop":
+            return thumbnailer.crop(t_width, t_height, t_type)
+        elif t_method == "scale":
+            t_width, t_height = thumbnailer.aspect(t_width, t_height)
+            t_width = min(m_width, t_width)
+            t_height = min(m_height, t_height)
+            return thumbnailer.scale(t_width, t_height, t_type)
+
+        return None
+
+    async def generate_local_exact_thumbnail(
+        self,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+        url_cache: bool,
+    ) -> Optional[str]:
+        input_path = await self.media_storage.ensure_media_is_in_local_cache(
+            FileInfo(None, media_id, url_cache=url_cache)
+        )
+
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+                media_id,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
+        with thumbnailer:
+            t_byte_source = await defer_to_thread(
+                self.hs.get_reactor(),
+                self._generate_thumbnail,
+                thumbnailer,
+                t_width,
+                t_height,
+                t_method,
+                t_type,
+            )
+
+        if t_byte_source:
+            try:
+                file_info = FileInfo(
+                    server_name=None,
+                    file_id=media_id,
+                    url_cache=url_cache,
+                    thumbnail=ThumbnailInfo(
+                        width=t_width,
+                        height=t_height,
+                        method=t_method,
+                        type=t_type,
+                    ),
+                )
+
+                output_path = await self.media_storage.store_file(
+                    t_byte_source, file_info
+                )
+            finally:
+                t_byte_source.close()
+
+            logger.info("Stored thumbnail in file %r", output_path)
+
+            t_len = os.path.getsize(output_path)
+
+            await self.store.store_local_thumbnail(
+                media_id, t_width, t_height, t_type, t_method, t_len
+            )
+
+            return output_path
+
+        # Could not generate thumbnail.
+        return None
+
+    async def generate_remote_exact_thumbnail(
+        self,
+        server_name: str,
+        file_id: str,
+        media_id: str,
+        t_width: int,
+        t_height: int,
+        t_method: str,
+        t_type: str,
+    ) -> Optional[str]:
+        input_path = await self.media_storage.ensure_media_is_in_local_cache(
+            FileInfo(server_name, file_id)
+        )
+
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+                media_id,
+                server_name,
+                t_method,
+                t_type,
+                e,
+            )
+            return None
+
+        with thumbnailer:
+            t_byte_source = await defer_to_thread(
+                self.hs.get_reactor(),
+                self._generate_thumbnail,
+                thumbnailer,
+                t_width,
+                t_height,
+                t_method,
+                t_type,
+            )
+
+        if t_byte_source:
+            try:
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,
+                    thumbnail=ThumbnailInfo(
+                        width=t_width,
+                        height=t_height,
+                        method=t_method,
+                        type=t_type,
+                    ),
+                )
+
+                output_path = await self.media_storage.store_file(
+                    t_byte_source, file_info
+                )
+            finally:
+                t_byte_source.close()
+
+            logger.info("Stored thumbnail in file %r", output_path)
+
+            t_len = os.path.getsize(output_path)
+
+            await self.store.store_remote_media_thumbnail(
+                server_name,
+                media_id,
+                file_id,
+                t_width,
+                t_height,
+                t_type,
+                t_method,
+                t_len,
+            )
+
+            return output_path
+
+        # Could not generate thumbnail.
+        return None
+
+    async def _generate_thumbnails(
+        self,
+        server_name: Optional[str],
+        media_id: str,
+        file_id: str,
+        media_type: str,
+        url_cache: bool = False,
+    ) -> Optional[dict]:
+        """Generate and store thumbnails for an image.
+
+        Args:
+            server_name: The server name if remote media, else None if local
+            media_id: The media ID of the content. (This is the same as
+                the file_id for local content)
+            file_id: Local file ID
+            media_type: The content type of the file
+            url_cache: If we are thumbnailing images downloaded for the URL cache,
+                used exclusively by the url previewer
+
+        Returns:
+            Dict with "width" and "height" keys of original image or None if the
+            media cannot be thumbnailed.
+        """
+        requirements = self._get_thumbnail_requirements(media_type)
+        if not requirements:
+            return None
+
+        input_path = await self.media_storage.ensure_media_is_in_local_cache(
+            FileInfo(server_name, file_id, url_cache=url_cache)
+        )
+
+        try:
+            thumbnailer = Thumbnailer(input_path)
+        except ThumbnailError as e:
+            logger.warning(
+                "Unable to generate thumbnails for remote media %s from %s of type %s: %s",
+                media_id,
+                server_name,
+                media_type,
+                e,
+            )
+            return None
+
+        with thumbnailer:
+            m_width = thumbnailer.width
+            m_height = thumbnailer.height
+
+            if m_width * m_height >= self.max_image_pixels:
+                logger.info(
+                    "Image too large to thumbnail %r x %r > %r",
+                    m_width,
+                    m_height,
+                    self.max_image_pixels,
+                )
+                return None
+
+            if thumbnailer.transpose_method is not None:
+                m_width, m_height = await defer_to_thread(
+                    self.hs.get_reactor(), thumbnailer.transpose
+                )
+
+            # We deduplicate the thumbnail sizes by ignoring the cropped versions if
+            # they have the same dimensions of a scaled one.
+            thumbnails: Dict[Tuple[int, int, str], str] = {}
+            for requirement in requirements:
+                if requirement.method == "crop":
+                    thumbnails.setdefault(
+                        (requirement.width, requirement.height, requirement.media_type),
+                        requirement.method,
+                    )
+                elif requirement.method == "scale":
+                    t_width, t_height = thumbnailer.aspect(
+                        requirement.width, requirement.height
+                    )
+                    t_width = min(m_width, t_width)
+                    t_height = min(m_height, t_height)
+                    thumbnails[
+                        (t_width, t_height, requirement.media_type)
+                    ] = requirement.method
+
+            # Now we generate the thumbnails for each dimension, store it
+            for (t_width, t_height, t_type), t_method in thumbnails.items():
+                # Generate the thumbnail
+                if t_method == "crop":
+                    t_byte_source = await defer_to_thread(
+                        self.hs.get_reactor(),
+                        thumbnailer.crop,
+                        t_width,
+                        t_height,
+                        t_type,
+                    )
+                elif t_method == "scale":
+                    t_byte_source = await defer_to_thread(
+                        self.hs.get_reactor(),
+                        thumbnailer.scale,
+                        t_width,
+                        t_height,
+                        t_type,
+                    )
+                else:
+                    logger.error("Unrecognized method: %r", t_method)
+                    continue
+
+                if not t_byte_source:
+                    continue
+
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,
+                    url_cache=url_cache,
+                    thumbnail=ThumbnailInfo(
+                        width=t_width,
+                        height=t_height,
+                        method=t_method,
+                        type=t_type,
+                    ),
+                )
+
+                with self.media_storage.store_into_file(file_info) as (
+                    f,
+                    fname,
+                    finish,
+                ):
+                    try:
+                        await self.media_storage.write_to_file(t_byte_source, f)
+                        await finish()
+                    finally:
+                        t_byte_source.close()
+
+                    t_len = os.path.getsize(fname)
+
+                    # Write to database
+                    if server_name:
+                        # Multiple remote media download requests can race (when
+                        # using multiple media repos), so this may throw a violation
+                        # constraint exception. If it does we'll delete the newly
+                        # generated thumbnail from disk (as we're in the ctx
+                        # manager).
+                        #
+                        # However: we've already called `finish()` so we may have
+                        # also written to the storage providers. This is preferable
+                        # to the alternative where we call `finish()` *after* this,
+                        # where we could end up having an entry in the DB but fail
+                        # to write the files to the storage providers.
+                        try:
+                            await self.store.store_remote_media_thumbnail(
+                                server_name,
+                                media_id,
+                                file_id,
+                                t_width,
+                                t_height,
+                                t_type,
+                                t_method,
+                                t_len,
+                            )
+                        except Exception as e:
+                            thumbnail_exists = (
+                                await self.store.get_remote_media_thumbnail(
+                                    server_name,
+                                    media_id,
+                                    t_width,
+                                    t_height,
+                                    t_type,
+                                )
+                            )
+                            if not thumbnail_exists:
+                                raise e
+                    else:
+                        await self.store.store_local_thumbnail(
+                            media_id, t_width, t_height, t_type, t_method, t_len
+                        )
+
+        return {"width": m_width, "height": m_height}
+
+    async def _apply_media_retention_rules(self) -> None:
+        """
+        Purge old local and remote media according to the media retention rules
+        defined in the homeserver config.
+        """
+        # Purge remote media
+        if self._media_retention_remote_media_lifetime_ms is not None:
+            # Calculate a threshold timestamp derived from the configured lifetime. Any
+            # media that has not been accessed since this timestamp will be removed.
+            remote_media_threshold_timestamp_ms = (
+                self.clock.time_msec() - self._media_retention_remote_media_lifetime_ms
+            )
+
+            logger.info(
+                "Purging remote media last accessed before"
+                f" {remote_media_threshold_timestamp_ms}"
+            )
+
+            await self.delete_old_remote_media(
+                before_ts=remote_media_threshold_timestamp_ms
+            )
+
+        # And now do the same for local media
+        if self._media_retention_local_media_lifetime_ms is not None:
+            # This works the same as the remote media threshold
+            local_media_threshold_timestamp_ms = (
+                self.clock.time_msec() - self._media_retention_local_media_lifetime_ms
+            )
+
+            logger.info(
+                "Purging local media last accessed before"
+                f" {local_media_threshold_timestamp_ms}"
+            )
+
+            await self.delete_old_local_media(
+                before_ts=local_media_threshold_timestamp_ms,
+                keep_profiles=True,
+                delete_quarantined_media=False,
+                delete_protected_media=False,
+            )
+
+    async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
+        old_media = await self.store.get_remote_media_ids(
+            before_ts, include_quarantined_media=False
+        )
+
+        deleted = 0
+
+        for media in old_media:
+            origin = media["media_origin"]
+            media_id = media["media_id"]
+            file_id = media["filesystem_id"]
+            key = (origin, media_id)
+
+            logger.info("Deleting: %r", key)
+
+            # TODO: Should we delete from the backup store
+
+            async with self.remote_media_linearizer.queue(key):
+                full_path = self.filepaths.remote_media_filepath(origin, file_id)
+                try:
+                    os.remove(full_path)
+                except OSError as e:
+                    logger.warning("Failed to remove file: %r", full_path)
+                    if e.errno == errno.ENOENT:
+                        pass
+                    else:
+                        continue
+
+                thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
+                    origin, file_id
+                )
+                shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+                await self.store.delete_remote_media(origin, media_id)
+                deleted += 1
+
+        return {"deleted": deleted}
+
+    async def delete_local_media_ids(
+        self, media_ids: List[str]
+    ) -> Tuple[List[str], int]:
+        """
+        Delete the given local or remote media ID from this server
+
+        Args:
+            media_id: The media ID to delete.
+        Returns:
+            A tuple of (list of deleted media IDs, total deleted media IDs).
+        """
+        return await self._remove_local_media_from_disk(media_ids)
+
+    async def delete_old_local_media(
+        self,
+        before_ts: int,
+        size_gt: int = 0,
+        keep_profiles: bool = True,
+        delete_quarantined_media: bool = False,
+        delete_protected_media: bool = False,
+    ) -> Tuple[List[str], int]:
+        """
+        Delete local or remote media from this server by size and timestamp. Removes
+        media files, any thumbnails and cached URLs.
+
+        Args:
+            before_ts: Unix timestamp in ms.
+                Files that were last used before this timestamp will be deleted.
+            size_gt: Size of the media in bytes. Files that are larger will be deleted.
+            keep_profiles: Switch to delete also files that are still used in image data
+                (e.g user profile, room avatar). If false these files will be deleted.
+            delete_quarantined_media: If True, media marked as quarantined will be deleted.
+            delete_protected_media: If True, media marked as protected will be deleted.
+
+        Returns:
+            A tuple of (list of deleted media IDs, total deleted media IDs).
+        """
+        old_media = await self.store.get_local_media_ids(
+            before_ts,
+            size_gt,
+            keep_profiles,
+            include_quarantined_media=delete_quarantined_media,
+            include_protected_media=delete_protected_media,
+        )
+        return await self._remove_local_media_from_disk(old_media)
+
+    async def _remove_local_media_from_disk(
+        self, media_ids: List[str]
+    ) -> Tuple[List[str], int]:
+        """
+        Delete local or remote media from this server. Removes media files,
+        any thumbnails and cached URLs.
+
+        Args:
+            media_ids: List of media_id to delete
+        Returns:
+            A tuple of (list of deleted media IDs, total deleted media IDs).
+        """
+        removed_media = []
+        for media_id in media_ids:
+            logger.info("Deleting media with ID '%s'", media_id)
+            full_path = self.filepaths.local_media_filepath(media_id)
+            try:
+                os.remove(full_path)
+            except OSError as e:
+                logger.warning("Failed to remove file: %r: %s", full_path, e)
+                if e.errno == errno.ENOENT:
+                    pass
+                else:
+                    continue
+
+            thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id)
+            shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+            await self.store.delete_remote_media(self.server_name, media_id)
+
+            await self.store.delete_url_cache((media_id,))
+            await self.store.delete_url_cache_media((media_id,))
+
+            removed_media.append(media_id)
+
+        return removed_media, len(removed_media)
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
new file mode 100644
index 0000000000..a7e22a91e1
--- /dev/null
+++ b/synapse/media/media_storage.py
@@ -0,0 +1,374 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import logging
+import os
+import shutil
+from types import TracebackType
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    BinaryIO,
+    Callable,
+    Generator,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+)
+
+import attr
+
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
+from twisted.protocols.basic import FileSender
+
+import synapse
+from synapse.api.errors import NotFoundError
+from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+    from synapse.media.storage_provider import StorageProvider
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage:
+    """Responsible for storing/fetching files from local sources.
+
+    Args:
+        hs
+        local_media_directory: Base path where we store media on disk
+        filepaths
+        storage_providers: List of StorageProvider that are used to fetch and store files.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        local_media_directory: str,
+        filepaths: MediaFilePaths,
+        storage_providers: Sequence["StorageProvider"],
+    ):
+        self.hs = hs
+        self.reactor = hs.get_reactor()
+        self.local_media_directory = local_media_directory
+        self.filepaths = filepaths
+        self.storage_providers = storage_providers
+        self.spam_checker = hs.get_spam_checker()
+        self.clock = hs.get_clock()
+
+    async def store_file(self, source: IO, file_info: FileInfo) -> str:
+        """Write `source` to the on disk media store, and also any other
+        configured storage providers
+
+        Args:
+            source: A file like object that should be written
+            file_info: Info about the file to store
+
+        Returns:
+            the file path written to in the primary media store
+        """
+
+        with self.store_into_file(file_info) as (f, fname, finish_cb):
+            # Write to the main repository
+            await self.write_to_file(source, f)
+            await finish_cb()
+
+        return fname
+
+    async def write_to_file(self, source: IO, output: IO) -> None:
+        """Asynchronously write the `source` to `output`."""
+        await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
+    @contextlib.contextmanager
+    def store_into_file(
+        self, file_info: FileInfo
+    ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
+        """Context manager used to get a file like object to write into, as
+        described by file_info.
+
+        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+        like object that can be written to, fname is the absolute path of file
+        on disk, and finish_cb is a function that returns an awaitable.
+
+        fname can be used to read the contents from after upload, e.g. to
+        generate thumbnails.
+
+        finish_cb must be called and waited on after the file has been
+        successfully been written to. Should not be called if there was an
+        error.
+
+        Args:
+            file_info: Info about the file to store
+
+        Example:
+
+            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+                # .. write into f ...
+                await finish_cb()
+        """
+
+        path = self._file_info_to_path(file_info)
+        fname = os.path.join(self.local_media_directory, path)
+
+        dirname = os.path.dirname(fname)
+        os.makedirs(dirname, exist_ok=True)
+
+        finished_called = [False]
+
+        try:
+            with open(fname, "wb") as f:
+
+                async def finish() -> None:
+                    # Ensure that all writes have been flushed and close the
+                    # file.
+                    f.flush()
+                    f.close()
+
+                    spam_check = await self.spam_checker.check_media_file_for_spam(
+                        ReadableFileWrapper(self.clock, fname), file_info
+                    )
+                    if spam_check != synapse.module_api.NOT_SPAM:
+                        logger.info("Blocking media due to spam checker")
+                        # Note that we'll delete the stored media, due to the
+                        # try/except below. The media also won't be stored in
+                        # the DB.
+                        # We currently ignore any additional field returned by
+                        # the spam-check API.
+                        raise SpamMediaException(errcode=spam_check[0])
+
+                    for provider in self.storage_providers:
+                        await provider.store_file(path, file_info)
+
+                    finished_called[0] = True
+
+                yield f, fname, finish
+        except Exception as e:
+            try:
+                os.remove(fname)
+            except Exception:
+                pass
+
+            raise e from None
+
+        if not finished_called:
+            raise Exception("Finished callback not called")
+
+    async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
+        """Attempts to fetch media described by file_info from the local cache
+        and configured storage providers.
+
+        Args:
+            file_info
+
+        Returns:
+            Returns a Responder if the file was found, otherwise None.
+        """
+        paths = [self._file_info_to_path(file_info)]
+
+        # fallback for remote thumbnails with no method in the filename
+        if file_info.thumbnail and file_info.server_name:
+            paths.append(
+                self.filepaths.remote_media_thumbnail_rel_legacy(
+                    server_name=file_info.server_name,
+                    file_id=file_info.file_id,
+                    width=file_info.thumbnail.width,
+                    height=file_info.thumbnail.height,
+                    content_type=file_info.thumbnail.type,
+                )
+            )
+
+        for path in paths:
+            local_path = os.path.join(self.local_media_directory, path)
+            if os.path.exists(local_path):
+                logger.debug("responding with local file %s", local_path)
+                return FileResponder(open(local_path, "rb"))
+            logger.debug("local file %s did not exist", local_path)
+
+        for provider in self.storage_providers:
+            for path in paths:
+                res: Any = await provider.fetch(path, file_info)
+                if res:
+                    logger.debug("Streaming %s from %s", path, provider)
+                    return res
+                logger.debug("%s not found on %s", path, provider)
+
+        return None
+
+    async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
+        """Ensures that the given file is in the local cache. Attempts to
+        download it from storage providers if it isn't.
+
+        Args:
+            file_info
+
+        Returns:
+            Full path to local file
+        """
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            return local_path
+
+        # Fallback for paths without method names
+        # Should be removed in the future
+        if file_info.thumbnail and file_info.server_name:
+            legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+                server_name=file_info.server_name,
+                file_id=file_info.file_id,
+                width=file_info.thumbnail.width,
+                height=file_info.thumbnail.height,
+                content_type=file_info.thumbnail.type,
+            )
+            legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+            if os.path.exists(legacy_local_path):
+                return legacy_local_path
+
+        dirname = os.path.dirname(local_path)
+        os.makedirs(dirname, exist_ok=True)
+
+        for provider in self.storage_providers:
+            res: Any = await provider.fetch(path, file_info)
+            if res:
+                with res:
+                    consumer = BackgroundFileConsumer(
+                        open(local_path, "wb"), self.reactor
+                    )
+                    await res.write_to_consumer(consumer)
+                    await consumer.wait()
+                return local_path
+
+        raise NotFoundError()
+
+    def _file_info_to_path(self, file_info: FileInfo) -> str:
+        """Converts file_info into a relative path.
+
+        The path is suitable for storing files under a directory, e.g. used to
+        store files on local FS under the base media repository directory.
+        """
+        if file_info.url_cache:
+            if file_info.thumbnail:
+                return self.filepaths.url_cache_thumbnail_rel(
+                    media_id=file_info.file_id,
+                    width=file_info.thumbnail.width,
+                    height=file_info.thumbnail.height,
+                    content_type=file_info.thumbnail.type,
+                    method=file_info.thumbnail.method,
+                )
+            return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+        if file_info.server_name:
+            if file_info.thumbnail:
+                return self.filepaths.remote_media_thumbnail_rel(
+                    server_name=file_info.server_name,
+                    file_id=file_info.file_id,
+                    width=file_info.thumbnail.width,
+                    height=file_info.thumbnail.height,
+                    content_type=file_info.thumbnail.type,
+                    method=file_info.thumbnail.method,
+                )
+            return self.filepaths.remote_media_filepath_rel(
+                file_info.server_name, file_info.file_id
+            )
+
+        if file_info.thumbnail:
+            return self.filepaths.local_media_thumbnail_rel(
+                media_id=file_info.file_id,
+                width=file_info.thumbnail.width,
+                height=file_info.thumbnail.height,
+                content_type=file_info.thumbnail.type,
+                method=file_info.thumbnail.method,
+            )
+        return self.filepaths.local_media_filepath_rel(file_info.file_id)
+
+
+def _write_file_synchronously(source: IO, dest: IO) -> None:
+    """Write `source` to the file like `dest` synchronously. Should be called
+    from a thread.
+
+    Args:
+        source: A file like object that's to be written
+        dest: A file like object to be written to
+    """
+    source.seek(0)  # Ensure we read from the start of the file
+    shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+    """Wraps an open file that can be sent to a request.
+
+    Args:
+        open_file: A file like object to be streamed ot the client,
+            is closed when finished streaming.
+    """
+
+    def __init__(self, open_file: IO):
+        self.open_file = open_file
+
+    def write_to_consumer(self, consumer: IConsumer) -> Deferred:
+        return make_deferred_yieldable(
+            FileSender().beginFileTransfer(self.open_file, consumer)
+        )
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+    """The media was blocked by a spam checker, so we simply 404 the request (in
+    the same way as if it was quarantined).
+    """
+
+
+@attr.s(slots=True, auto_attribs=True)
+class ReadableFileWrapper:
+    """Wrapper that allows reading a file in chunks, yielding to the reactor,
+    and writing to a callback.
+
+    This is simplified `FileSender` that takes an IO object rather than an
+    `IConsumer`.
+    """
+
+    CHUNK_SIZE = 2**14
+
+    clock: Clock
+    path: str
+
+    async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
+        """Reads the file in chunks and calls the callback with each chunk."""
+
+        with open(self.path, "rb") as file:
+            while True:
+                chunk = file.read(self.CHUNK_SIZE)
+                if not chunk:
+                    break
+
+                callback(chunk)
+
+                # We yield to the reactor by sleeping for 0 seconds.
+                await self.clock.sleep(0)
diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py
new file mode 100644
index 0000000000..c0eaf04be5
--- /dev/null
+++ b/synapse/media/oembed.py
@@ -0,0 +1,265 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+import html
+import logging
+import urllib.parse
+from typing import TYPE_CHECKING, List, Optional
+
+import attr
+
+from synapse.media.preview_html import parse_html_description
+from synapse.types import JsonDict
+from synapse.util import json_decoder
+
+if TYPE_CHECKING:
+    from lxml import etree
+
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class OEmbedResult:
+    # The Open Graph result (converted from the oEmbed result).
+    open_graph_result: JsonDict
+    # The author_name of the oEmbed result
+    author_name: Optional[str]
+    # Number of milliseconds to cache the content, according to the oEmbed response.
+    #
+    # This will be None if no cache-age is provided in the oEmbed response (or
+    # if the oEmbed response cannot be turned into an Open Graph response).
+    cache_age: Optional[int]
+
+
+class OEmbedProvider:
+    """
+    A helper for accessing oEmbed content.
+
+    It can be used to check if a URL should be accessed via oEmbed and for
+    requesting/parsing oEmbed content.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._oembed_patterns = {}
+        for oembed_endpoint in hs.config.oembed.oembed_patterns:
+            api_endpoint = oembed_endpoint.api_endpoint
+
+            # Only JSON is supported at the moment. This could be declared in
+            # the formats field. Otherwise, if the endpoint ends in .xml assume
+            # it doesn't support JSON.
+            if (
+                oembed_endpoint.formats is not None
+                and "json" not in oembed_endpoint.formats
+            ) or api_endpoint.endswith(".xml"):
+                logger.info(
+                    "Ignoring oEmbed endpoint due to not supporting JSON: %s",
+                    api_endpoint,
+                )
+                continue
+
+            # Iterate through each URL pattern and point it to the endpoint.
+            for pattern in oembed_endpoint.url_patterns:
+                self._oembed_patterns[pattern] = api_endpoint
+
+    def get_oembed_url(self, url: str) -> Optional[str]:
+        """
+        Check whether the URL should be downloaded as oEmbed content instead.
+
+        Args:
+            url: The URL to check.
+
+        Returns:
+            A URL to use instead or None if the original URL should be used.
+        """
+        for url_pattern, endpoint in self._oembed_patterns.items():
+            if url_pattern.fullmatch(url):
+                # TODO Specify max height / width.
+
+                # Note that only the JSON format is supported, some endpoints want
+                # this in the URL, others want it as an argument.
+                endpoint = endpoint.replace("{format}", "json")
+
+                args = {"url": url, "format": "json"}
+                query_str = urllib.parse.urlencode(args, True)
+                return f"{endpoint}?{query_str}"
+
+        # No match.
+        return None
+
+    def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]:
+        """
+        Search an HTML document for oEmbed autodiscovery information.
+
+        Args:
+            tree: The parsed HTML body.
+
+        Returns:
+            The URL to use for oEmbed information, or None if no URL was found.
+        """
+        # Search for link elements with the proper rel and type attributes.
+        for tag in tree.xpath(
+            "//link[@rel='alternate'][@type='application/json+oembed']"
+        ):
+            if "href" in tag.attrib:
+                return tag.attrib["href"]
+
+        # Some providers (e.g. Flickr) use alternative instead of alternate.
+        for tag in tree.xpath(
+            "//link[@rel='alternative'][@type='application/json+oembed']"
+        ):
+            if "href" in tag.attrib:
+                return tag.attrib["href"]
+
+        return None
+
+    def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
+        """
+        Parse the oEmbed response into an Open Graph response.
+
+        Args:
+            url: The URL which is being previewed (not the one which was
+                requested).
+            raw_body: The oEmbed response as JSON encoded as bytes.
+
+        Returns:
+            json-encoded Open Graph data
+        """
+
+        try:
+            # oEmbed responses *must* be UTF-8 according to the spec.
+            oembed = json_decoder.decode(raw_body.decode("utf-8"))
+        except ValueError:
+            return OEmbedResult({}, None, None)
+
+        # The version is a required string field, but not always provided,
+        # or sometimes provided as a float. Be lenient.
+        oembed_version = oembed.get("version", "1.0")
+        if oembed_version != "1.0" and oembed_version != 1:
+            return OEmbedResult({}, None, None)
+
+        # Attempt to parse the cache age, if possible.
+        try:
+            cache_age = int(oembed.get("cache_age")) * 1000
+        except (TypeError, ValueError):
+            # If the cache age cannot be parsed (e.g. wrong type or invalid
+            # string), ignore it.
+            cache_age = None
+
+        # The oEmbed response converted to Open Graph.
+        open_graph_response: JsonDict = {"og:url": url}
+
+        title = oembed.get("title")
+        if title and isinstance(title, str):
+            # A common WordPress plug-in seems to incorrectly escape entities
+            # in the oEmbed response.
+            open_graph_response["og:title"] = html.unescape(title)
+
+        author_name = oembed.get("author_name")
+        if not isinstance(author_name, str):
+            author_name = None
+
+        # Use the provider name and as the site.
+        provider_name = oembed.get("provider_name")
+        if provider_name and isinstance(provider_name, str):
+            open_graph_response["og:site_name"] = provider_name
+
+        # If a thumbnail exists, use it. Note that dimensions will be calculated later.
+        thumbnail_url = oembed.get("thumbnail_url")
+        if thumbnail_url and isinstance(thumbnail_url, str):
+            open_graph_response["og:image"] = thumbnail_url
+
+        # Process each type separately.
+        oembed_type = oembed.get("type")
+        if oembed_type == "rich":
+            html_str = oembed.get("html")
+            if isinstance(html_str, str):
+                calc_description_and_urls(open_graph_response, html_str)
+
+        elif oembed_type == "photo":
+            # If this is a photo, use the full image, not the thumbnail.
+            url = oembed.get("url")
+            if url and isinstance(url, str):
+                open_graph_response["og:image"] = url
+
+        elif oembed_type == "video":
+            open_graph_response["og:type"] = "video.other"
+            html_str = oembed.get("html")
+            if html_str and isinstance(html_str, str):
+                calc_description_and_urls(open_graph_response, oembed["html"])
+            for size in ("width", "height"):
+                val = oembed.get(size)
+                if type(val) is int:
+                    open_graph_response[f"og:video:{size}"] = val
+
+        elif oembed_type == "link":
+            open_graph_response["og:type"] = "website"
+
+        else:
+            logger.warning("Unknown oEmbed type: %s", oembed_type)
+
+        return OEmbedResult(open_graph_response, author_name, cache_age)
+
+
+def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
+    results = []
+    for tag in tree.xpath("//*/" + tag_name):
+        if "src" in tag.attrib:
+            results.append(tag.attrib["src"])
+    return results
+
+
+def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None:
+    """
+    Calculate description for an HTML document.
+
+    This uses lxml to convert the HTML document into plaintext. If errors
+    occur during processing of the document, an empty response is returned.
+
+    Args:
+        open_graph_response: The current Open Graph summary. This is updated with additional fields.
+        html_body: The HTML document, as bytes.
+
+    Returns:
+        The summary
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not html_body:
+        return
+
+    from lxml import etree
+
+    # Create an HTML parser. If this fails, log and return no metadata.
+    parser = etree.HTMLParser(recover=True, encoding="utf-8")
+
+    # Attempt to parse the body. If this fails, log and return no metadata.
+    tree = etree.fromstring(html_body, parser)
+
+    # The data was successfully parsed, but no tree was found.
+    if tree is None:
+        return
+
+    # Attempt to find interesting URLs (images, videos, embeds).
+    if "og:image" not in open_graph_response:
+        image_urls = _fetch_urls(tree, "img")
+        if image_urls:
+            open_graph_response["og:image"] = image_urls[0]
+
+    video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed")
+    if video_urls:
+        open_graph_response["og:video"] = video_urls[0]
+
+    description = parse_html_description(tree)
+    if description:
+        open_graph_response["og:description"] = description
diff --git a/synapse/media/preview_html.py b/synapse/media/preview_html.py
new file mode 100644
index 0000000000..516d0434f0
--- /dev/null
+++ b/synapse/media/preview_html.py
@@ -0,0 +1,501 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import logging
+import re
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Union,
+)
+
+if TYPE_CHECKING:
+    from lxml import etree
+
+logger = logging.getLogger(__name__)
+
+_charset_match = re.compile(
+    rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
+_xml_encoding_match = re.compile(
+    rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
+)
+_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+
+# Certain elements aren't meant for display.
+ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"}
+
+
+def _normalise_encoding(encoding: str) -> Optional[str]:
+    """Use the Python codec's name as the normalised entry."""
+    try:
+        return codecs.lookup(encoding).name
+    except LookupError:
+        return None
+
+
+def _get_html_media_encodings(
+    body: bytes, content_type: Optional[str]
+) -> Iterable[str]:
+    """
+    Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
+
+    The precedence used for finding a character encoding is:
+
+    1. <meta> tag with a charset declared.
+    2. The XML document's character encoding attribute.
+    3. The Content-Type header.
+    4. Fallback to utf-8.
+    5. Fallback to windows-1252.
+
+    This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+
+    Args:
+        body: The HTML document, as bytes.
+        content_type: The Content-Type header.
+
+    Returns:
+        The character encoding of the body, as a string.
+    """
+    # There's no point in returning an encoding more than once.
+    attempted_encodings: Set[str] = set()
+
+    # Limit searches to the first 1kb, since it ought to be at the top.
+    body_start = body[:1024]
+
+    # Check if it has an encoding set in a meta tag.
+    match = _charset_match.search(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+    # Check if it has an XML document with an encoding.
+    match = _xml_encoding_match.match(body_start)
+    if match:
+        encoding = _normalise_encoding(match.group(1).decode("ascii"))
+        if encoding and encoding not in attempted_encodings:
+            attempted_encodings.add(encoding)
+            yield encoding
+
+    # Check the HTTP Content-Type header for a character set.
+    if content_type:
+        content_match = _content_type_match.match(content_type)
+        if content_match:
+            encoding = _normalise_encoding(content_match.group(1))
+            if encoding and encoding not in attempted_encodings:
+                attempted_encodings.add(encoding)
+                yield encoding
+
+    # Finally, fallback to UTF-8, then windows-1252.
+    for fallback in ("utf-8", "cp1252"):
+        if fallback not in attempted_encodings:
+            yield fallback
+
+
+def decode_body(
+    body: bytes, uri: str, content_type: Optional[str] = None
+) -> Optional["etree.Element"]:
+    """
+    This uses lxml to parse the HTML document.
+
+    Args:
+        body: The HTML document, as bytes.
+        uri: The URI used to download the body.
+        content_type: The Content-Type header.
+
+    Returns:
+        The parsed HTML body, or None if an error occurred during processed.
+    """
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return None
+
+    # The idea here is that multiple encodings are tried until one works.
+    # Unfortunately the result is never used and then LXML will decode the string
+    # again with the found encoding.
+    for encoding in _get_html_media_encodings(body, content_type):
+        try:
+            body.decode(encoding)
+        except Exception:
+            pass
+        else:
+            break
+    else:
+        logger.warning("Unable to decode HTML body for %s", uri)
+        return None
+
+    from lxml import etree
+
+    # Create an HTML parser.
+    parser = etree.HTMLParser(recover=True, encoding=encoding)
+
+    # Attempt to parse the body. Returns None if the body was successfully
+    # parsed, but no tree was found.
+    return etree.fromstring(body, parser)
+
+
+def _get_meta_tags(
+    tree: "etree.Element",
+    property: str,
+    prefix: str,
+    property_mapper: Optional[Callable[[str], Optional[str]]] = None,
+) -> Dict[str, Optional[str]]:
+    """
+    Search for meta tags prefixed with a particular string.
+
+    Args:
+        tree: The parsed HTML document.
+        property: The name of the property which contains the tag name, e.g.
+            "property" for Open Graph.
+        prefix: The prefix on the property to search for, e.g. "og" for Open Graph.
+        property_mapper: An optional callable to map the property to the Open Graph
+            form. Can return None for a key to ignore that key.
+
+    Returns:
+        A map of tag name to value.
+    """
+    results: Dict[str, Optional[str]] = {}
+    for tag in tree.xpath(
+        f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]"
+    ):
+        # if we've got more than 50 tags, someone is taking the piss
+        if len(results) >= 50:
+            logger.warning(
+                "Skipping parsing of Open Graph for page with too many '%s:' tags",
+                prefix,
+            )
+            return {}
+
+        key = tag.attrib[property]
+        if property_mapper:
+            key = property_mapper(key)
+            # None is a special value used to ignore a value.
+            if key is None:
+                continue
+
+        results[key] = tag.attrib["content"]
+
+    return results
+
+
+def _map_twitter_to_open_graph(key: str) -> Optional[str]:
+    """
+    Map a Twitter card property to the analogous Open Graph property.
+
+    Args:
+        key: The Twitter card property (starts with "twitter:").
+
+    Returns:
+        The Open Graph property (starts with "og:") or None to have this property
+        be ignored.
+    """
+    # Twitter card properties with no analogous Open Graph property.
+    if key == "twitter:card" or key == "twitter:creator":
+        return None
+    if key == "twitter:site":
+        return "og:site_name"
+    # Otherwise, swap twitter to og.
+    return "og" + key[7:]
+
+
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
+    """
+    Parse the HTML document into an Open Graph response.
+
+    This uses lxml to search the HTML document for Open Graph data (or
+    synthesizes it from the document).
+
+    Args:
+        tree: The parsed HTML document.
+
+    Returns:
+        The Open Graph response as a dictionary.
+    """
+
+    # Search for Open Graph (og:) meta tags, e.g.:
+    #
+    # "og:type"         : "video",
+    # "og:url"          : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+    # "og:site_name"    : "YouTube",
+    # "og:video:type"   : "application/x-shockwave-flash",
+    # "og:description"  : "Fun stuff happening here",
+    # "og:title"        : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+    # "og:image"        : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+    # "og:video:url"    : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+    # "og:video:width"  : "1280"
+    # "og:video:height" : "720",
+    # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+    og = _get_meta_tags(tree, "property", "og")
+
+    # TODO: Search for properties specific to the different Open Graph types,
+    # such as article: meta tags, e.g.:
+    #
+    # "article:publisher" : "https://www.facebook.com/thethudonline" />
+    # "article:author" content="https://www.facebook.com/thethudonline" />
+    # "article:tag" content="baby" />
+    # "article:section" content="Breaking News" />
+    # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+    # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+    # Search for Twitter Card (twitter:) meta tags, e.g.:
+    #
+    # "twitter:site"    : "@matrixdotorg"
+    # "twitter:creator" : "@matrixdotorg"
+    #
+    # Twitter cards tags also duplicate Open Graph tags.
+    #
+    # See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started
+    twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph)
+    # Merge the Twitter values with the Open Graph values, but do not overwrite
+    # information from Open Graph tags.
+    for key, value in twitter.items():
+        if key not in og:
+            og[key] = value
+
+    if "og:title" not in og:
+        # Attempt to find a title from the title tag, or the biggest header on the page.
+        title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
+        if title:
+            og["og:title"] = title[0].strip()
+        else:
+            og["og:title"] = None
+
+    if "og:image" not in og:
+        meta_image = tree.xpath(
+            "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
+        )
+        # If a meta image is found, use it.
+        if meta_image:
+            og["og:image"] = meta_image[0]
+        else:
+            # Try to find images which are larger than 10px by 10px.
+            #
+            # TODO: consider inlined CSS styles as well as width & height attribs
+            images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+            images = sorted(
+                images,
+                key=lambda i: (
+                    -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+                ),
+            )
+            # If no images were found, try to find *any* images.
+            if not images:
+                images = tree.xpath("//img[@src][1]")
+            if images:
+                og["og:image"] = images[0].attrib["src"]
+
+            # Finally, fallback to the favicon if nothing else.
+            else:
+                favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]")
+                if favicons:
+                    og["og:image"] = favicons[0]
+
+    if "og:description" not in og:
+        # Check the first meta description tag for content.
+        meta_description = tree.xpath(
+            "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
+        )
+        # If a meta description is found with content, use it.
+        if meta_description:
+            og["og:description"] = meta_description[0]
+        else:
+            og["og:description"] = parse_html_description(tree)
+    elif og["og:description"]:
+        # This must be a non-empty string at this point.
+        assert isinstance(og["og:description"], str)
+        og["og:description"] = summarize_paragraphs([og["og:description"]])
+
+    # TODO: delete the url downloads to stop diskfilling,
+    # as we only ever cared about its OG
+    return og
+
+
+def parse_html_description(tree: "etree.Element") -> Optional[str]:
+    """
+    Calculate a text description based on an HTML document.
+
+    Grabs any text nodes which are inside the <body/> tag, unless they are within
+    an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+    if they are within a <script/>, <svg/> or <style/> tag, or if they are within
+    a tag whose content is usually only shown to old browsers
+    (<iframe/>, <video/>, <canvas/>, <picture/>).
+
+    This is a very very very coarse approximation to a plain text render of the page.
+
+    Args:
+        tree: The parsed HTML document.
+
+    Returns:
+        The plain text description, or None if one cannot be generated.
+    """
+    # We don't just use XPATH here as that is slow on some machines.
+
+    from lxml import etree
+
+    TAGS_TO_REMOVE = {
+        "header",
+        "nav",
+        "aside",
+        "footer",
+        "script",
+        "noscript",
+        "style",
+        "svg",
+        "iframe",
+        "video",
+        "canvas",
+        "img",
+        "picture",
+        etree.Comment,
+    }
+
+    # Split all the text nodes into paragraphs (by splitting on new
+    # lines)
+    text_nodes = (
+        re.sub(r"\s+", "\n", el).strip()
+        for el in _iterate_over_text(tree.find("body"), TAGS_TO_REMOVE)
+    )
+    return summarize_paragraphs(text_nodes)
+
+
+def _iterate_over_text(
+    tree: Optional["etree.Element"],
+    tags_to_ignore: Set[Union[str, "etree.Comment"]],
+    stack_limit: int = 1024,
+) -> Generator[str, None, None]:
+    """Iterate over the tree returning text nodes in a depth first fashion,
+    skipping text nodes inside certain tags.
+
+    Args:
+        tree: The parent element to iterate. Can be None if there isn't one.
+        tags_to_ignore: Set of tags to ignore
+        stack_limit: Maximum stack size limit for depth-first traversal.
+            Nodes will be dropped if this limit is hit, which may truncate the
+            textual result.
+            Intended to limit the maximum working memory when generating a preview.
+    """
+
+    if tree is None:
+        return
+
+    # This is a stack whose items are elements to iterate over *or* strings
+    # to be returned.
+    elements: List[Union[str, "etree.Element"]] = [tree]
+    while elements:
+        el = elements.pop()
+
+        if isinstance(el, str):
+            yield el
+        elif el.tag not in tags_to_ignore:
+            # If the element isn't meant for display, ignore it.
+            if el.get("role") in ARIA_ROLES_TO_IGNORE:
+                continue
+
+            # el.text is the text before the first child, so we can immediately
+            # return it if the text exists.
+            if el.text:
+                yield el.text
+
+            # We add to the stack all the element's children, interspersed with
+            # each child's tail text (if it exists).
+            #
+            # We iterate in reverse order so that earlier pieces of text appear
+            # closer to the top of the stack.
+            for child in el.iterchildren(reversed=True):
+                if len(elements) > stack_limit:
+                    # We've hit our limit for working memory
+                    break
+
+                if child.tail:
+                    # The tail text of a node is text that comes *after* the node,
+                    # so we always include it even if we ignore the child node.
+                    elements.append(child.tail)
+
+                elements.append(child)
+
+
+def summarize_paragraphs(
+    text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
+    """
+    Try to get a summary respecting first paragraph and then word boundaries.
+
+    Args:
+        text_nodes: The paragraphs to summarize.
+        min_size: The minimum number of words to include.
+        max_size: The maximum number of words to include.
+
+    Returns:
+        A summary of the text nodes, or None if that was not possible.
+    """
+
+    # TODO: Respect sentences?
+
+    description = ""
+
+    # Keep adding paragraphs until we get to the MIN_SIZE.
+    for text_node in text_nodes:
+        if len(description) < min_size:
+            text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+            description += text_node + "\n\n"
+        else:
+            break
+
+    description = description.strip()
+    description = re.sub(r"[\t ]+", " ", description)
+    description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
+
+    # If the concatenation of paragraphs to get above MIN_SIZE
+    # took us over MAX_SIZE, then we need to truncate mid paragraph
+    if len(description) > max_size:
+        new_desc = ""
+
+        # This splits the paragraph into words, but keeping the
+        # (preceding) whitespace intact so we can easily concat
+        # words back together.
+        for match in re.finditer(r"\s*\S+", description):
+            word = match.group()
+
+            # Keep adding words while the total length is less than
+            # MAX_SIZE.
+            if len(word) + len(new_desc) < max_size:
+                new_desc += word
+            else:
+                # At this point the next word *will* take us over
+                # MAX_SIZE, but we also want to ensure that its not
+                # a huge word. If it is add it anyway and we'll
+                # truncate later.
+                if len(new_desc) < min_size:
+                    new_desc += word
+                break
+
+        # Double check that we're not over the limit
+        if len(new_desc) > max_size:
+            new_desc = new_desc[:max_size]
+
+        # We always add an ellipsis because at the very least
+        # we chopped mid paragraph.
+        description = new_desc.strip() + "…"
+    return description if description else None
diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py
new file mode 100644
index 0000000000..1c9b71d69c
--- /dev/null
+++ b/synapse/media/storage_provider.py
@@ -0,0 +1,181 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+import os
+import shutil
+from typing import TYPE_CHECKING, Callable, Optional
+
+from synapse.config._base import Config
+from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
+
+from ._base import FileInfo, Responder
+from .media_storage import FileResponder
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class StorageProvider(metaclass=abc.ABCMeta):
+    """A storage provider is a service that can store uploaded media and
+    retrieve them.
+    """
+
+    @abc.abstractmethod
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
+        """Store the file described by file_info. The actual contents can be
+        retrieved by reading the file in file_info.upload_path.
+
+        Args:
+            path: Relative path of file in local cache
+            file_info: The metadata of the file.
+        """
+
+    @abc.abstractmethod
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+        """Attempt to fetch the file described by file_info and stream it
+        into writer.
+
+        Args:
+            path: Relative path of file in local cache
+            file_info: The metadata of the file.
+
+        Returns:
+            Returns a Responder if the provider has the file, otherwise returns None.
+        """
+
+
+class StorageProviderWrapper(StorageProvider):
+    """Wraps a storage provider and provides various config options
+
+    Args:
+        backend: The storage provider to wrap.
+        store_local: Whether to store new local files or not.
+        store_synchronous: Whether to wait for file to be successfully
+            uploaded, or todo the upload in the background.
+        store_remote: Whether remote media should be uploaded
+    """
+
+    def __init__(
+        self,
+        backend: StorageProvider,
+        store_local: bool,
+        store_synchronous: bool,
+        store_remote: bool,
+    ):
+        self.backend = backend
+        self.store_local = store_local
+        self.store_synchronous = store_synchronous
+        self.store_remote = store_remote
+
+    def __str__(self) -> str:
+        return "StorageProviderWrapper[%s]" % (self.backend,)
+
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
+        if not file_info.server_name and not self.store_local:
+            return None
+
+        if file_info.server_name and not self.store_remote:
+            return None
+
+        if file_info.url_cache:
+            # The URL preview cache is short lived and not worth offloading or
+            # backing up.
+            return None
+
+        if self.store_synchronous:
+            # store_file is supposed to return an Awaitable, but guard
+            # against improper implementations.
+            await maybe_awaitable(self.backend.store_file(path, file_info))  # type: ignore
+        else:
+            # TODO: Handle errors.
+            async def store() -> None:
+                try:
+                    return await maybe_awaitable(
+                        self.backend.store_file(path, file_info)
+                    )
+                except Exception:
+                    logger.exception("Error storing file")
+
+            run_in_background(store)
+
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+        if file_info.url_cache:
+            # Files in the URL preview cache definitely aren't stored here,
+            # so avoid any potentially slow I/O or network access.
+            return None
+
+        # store_file is supposed to return an Awaitable, but guard
+        # against improper implementations.
+        return await maybe_awaitable(self.backend.fetch(path, file_info))
+
+
+class FileStorageProviderBackend(StorageProvider):
+    """A storage provider that stores files in a directory on a filesystem.
+
+    Args:
+        hs
+        config: The config returned by `parse_config`.
+    """
+
+    def __init__(self, hs: "HomeServer", config: str):
+        self.hs = hs
+        self.cache_directory = hs.config.media.media_store_path
+        self.base_directory = config
+
+    def __str__(self) -> str:
+        return "FileStorageProviderBackend[%s]" % (self.base_directory,)
+
+    async def store_file(self, path: str, file_info: FileInfo) -> None:
+        """See StorageProvider.store_file"""
+
+        primary_fname = os.path.join(self.cache_directory, path)
+        backup_fname = os.path.join(self.base_directory, path)
+
+        dirname = os.path.dirname(backup_fname)
+        os.makedirs(dirname, exist_ok=True)
+
+        # mypy needs help inferring the type of the second parameter, which is generic
+        shutil_copyfile: Callable[[str, str], str] = shutil.copyfile
+        await defer_to_thread(
+            self.hs.get_reactor(),
+            shutil_copyfile,
+            primary_fname,
+            backup_fname,
+        )
+
+    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
+        """See StorageProvider.fetch"""
+
+        backup_fname = os.path.join(self.base_directory, path)
+        if os.path.isfile(backup_fname):
+            return FileResponder(open(backup_fname, "rb"))
+
+        return None
+
+    @staticmethod
+    def parse_config(config: dict) -> str:
+        """Called on startup to parse config supplied. This should parse
+        the config and raise if there is a problem.
+
+        The returned value is passed into the constructor.
+
+        In this case we only care about a single param, the directory, so let's
+        just pull that out.
+        """
+        return Config.ensure_directory(config["directory"])
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
new file mode 100644
index 0000000000..f909a4fb9a
--- /dev/null
+++ b/synapse/media/thumbnailer.py
@@ -0,0 +1,221 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from io import BytesIO
+from types import TracebackType
+from typing import Optional, Tuple, Type
+
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+EXIF_ORIENTATION_TAG = 0x0112
+EXIF_TRANSPOSE_MAPPINGS = {
+    2: Image.FLIP_LEFT_RIGHT,
+    3: Image.ROTATE_180,
+    4: Image.FLIP_TOP_BOTTOM,
+    5: Image.TRANSPOSE,
+    6: Image.ROTATE_270,
+    7: Image.TRANSVERSE,
+    8: Image.ROTATE_90,
+}
+
+
+class ThumbnailError(Exception):
+    """An error occurred generating a thumbnail."""
+
+
+class Thumbnailer:
+    FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
+
+    @staticmethod
+    def set_limits(max_image_pixels: int) -> None:
+        Image.MAX_IMAGE_PIXELS = max_image_pixels
+
+    def __init__(self, input_path: str):
+        # Have we closed the image?
+        self._closed = False
+
+        try:
+            self.image = Image.open(input_path)
+        except OSError as e:
+            # If an error occurs opening the image, a thumbnail won't be able to
+            # be generated.
+            raise ThumbnailError from e
+        except Image.DecompressionBombError as e:
+            # If an image decompression bomb error occurs opening the image,
+            # then the image exceeds the pixel limit and a thumbnail won't
+            # be able to be generated.
+            raise ThumbnailError from e
+
+        self.width, self.height = self.image.size
+        self.transpose_method = None
+        try:
+            # We don't use ImageOps.exif_transpose since it crashes with big EXIF
+            #
+            # Ignore safety: Pillow seems to acknowledge that this method is
+            # "private, experimental, but generally widely used". Pillow 6
+            # includes a public getexif() method (no underscore) that we might
+            # consider using instead when we can bump that dependency.
+            #
+            # At the time of writing, Debian buster (currently oldstable)
+            # provides version 5.4.1. It's expected to EOL in mid-2022, see
+            # https://wiki.debian.org/DebianReleases#Production_Releases
+            image_exif = self.image._getexif()  # type: ignore
+            if image_exif is not None:
+                image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
+                assert type(image_orientation) is int
+                self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
+        except Exception as e:
+            # A lot of parsing errors can happen when parsing EXIF
+            logger.info("Error parsing image EXIF information: %s", e)
+
+    def transpose(self) -> Tuple[int, int]:
+        """Transpose the image using its EXIF Orientation tag
+
+        Returns:
+            A tuple containing the new image size in pixels as (width, height).
+        """
+        if self.transpose_method is not None:
+            # Safety: `transpose` takes an int rather than e.g. an IntEnum.
+            # self.transpose_method is set above to be a value in
+            # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
+            with self.image:
+                self.image = self.image.transpose(self.transpose_method)  # type: ignore[arg-type]
+            self.width, self.height = self.image.size
+            self.transpose_method = None
+            # We don't need EXIF any more
+            self.image.info["exif"] = None
+        return self.image.size
+
+    def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
+        """Calculate the largest size that preserves aspect ratio which
+        fits within the given rectangle::
+
+            (w_in / h_in) = (w_out / h_out)
+            w_out = max(min(w_max, h_max * (w_in / h_in)), 1)
+            h_out = max(min(h_max, w_max * (h_in / w_in)), 1)
+
+        Args:
+            max_width: The largest possible width.
+            max_height: The largest possible height.
+        """
+
+        if max_width * self.height < max_height * self.width:
+            return max_width, max((max_width * self.height) // self.width, 1)
+        else:
+            return max((max_height * self.width) // self.height, 1), max_height
+
+    def _resize(self, width: int, height: int) -> Image.Image:
+        # 1-bit or 8-bit color palette images need converting to RGB
+        # otherwise they will be scaled using nearest neighbour which
+        # looks awful.
+        #
+        # If the image has transparency, use RGBA instead.
+        if self.image.mode in ["1", "L", "P"]:
+            if self.image.info.get("transparency", None) is not None:
+                with self.image:
+                    self.image = self.image.convert("RGBA")
+            else:
+                with self.image:
+                    self.image = self.image.convert("RGB")
+        return self.image.resize((width, height), Image.ANTIALIAS)
+
+    def scale(self, width: int, height: int, output_type: str) -> BytesIO:
+        """Rescales the image to the given dimensions.
+
+        Returns:
+            The bytes of the encoded image ready to be written to disk
+        """
+        with self._resize(width, height) as scaled:
+            return self._encode_image(scaled, output_type)
+
+    def crop(self, width: int, height: int, output_type: str) -> BytesIO:
+        """Rescales and crops the image to the given dimensions preserving
+        aspect::
+            (w_in / h_in) = (w_scaled / h_scaled)
+            w_scaled = max(w_out, h_out * (w_in / h_in))
+            h_scaled = max(h_out, w_out * (h_in / w_in))
+
+        Args:
+            max_width: The largest possible width.
+            max_height: The largest possible height.
+
+        Returns:
+            The bytes of the encoded image ready to be written to disk
+        """
+        if width * self.height > height * self.width:
+            scaled_width = width
+            scaled_height = (width * self.height) // self.width
+            crop_top = (scaled_height - height) // 2
+            crop_bottom = height + crop_top
+            crop = (0, crop_top, width, crop_bottom)
+        else:
+            scaled_width = (height * self.width) // self.height
+            scaled_height = height
+            crop_left = (scaled_width - width) // 2
+            crop_right = width + crop_left
+            crop = (crop_left, 0, crop_right, height)
+
+        with self._resize(scaled_width, scaled_height) as scaled_image:
+            with scaled_image.crop(crop) as cropped:
+                return self._encode_image(cropped, output_type)
+
+    def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
+        output_bytes_io = BytesIO()
+        fmt = self.FORMATS[output_type]
+        if fmt == "JPEG":
+            output_image = output_image.convert("RGB")
+        output_image.save(output_bytes_io, fmt, quality=80)
+        return output_bytes_io
+
+    def close(self) -> None:
+        """Closes the underlying image file.
+
+        Once closed no other functions can be called.
+
+        Can be called multiple times.
+        """
+
+        if self._closed:
+            return
+
+        self._closed = True
+
+        # Since we run this on the finalizer then we need to handle `__init__`
+        # raising an exception before it can define `self.image`.
+        image = getattr(self, "image", None)
+        if image is None:
+            return
+
+        image.close()
+
+    def __enter__(self) -> "Thumbnailer":
+        """Make `Thumbnailer` a context manager that calls `close` on
+        `__exit__`.
+        """
+        return self
+
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
+    def __del__(self) -> None:
+        # Make sure we actually do close the image, rather than leak data.
+        self.close()