summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15503.feature1
-rw-r--r--docs/usage/configuration/config_documentation.md34
-rw-r--r--synapse/api/errors.py2
-rw-r--r--synapse/config/ratelimiting.py7
-rw-r--r--synapse/config/repository.py6
-rw-r--r--synapse/media/_base.py6
-rw-r--r--synapse/media/media_repository.py220
-rw-r--r--synapse/rest/media/create_resource.py83
-rw-r--r--synapse/rest/media/download_resource.py22
-rw-r--r--synapse/rest/media/media_repository_resource.py8
-rw-r--r--synapse/rest/media/thumbnail_resource.py69
-rw-r--r--synapse/rest/media/upload_resource.py75
-rw-r--r--synapse/storage/databases/main/media_repository.py90
-rw-r--r--tests/media/test_media_storage.py4
14 files changed, 568 insertions, 59 deletions
diff --git a/changelog.d/15503.feature b/changelog.d/15503.feature
new file mode 100644
index 0000000000..b6ca97a2cf
--- /dev/null
+++ b/changelog.d/15503.feature
@@ -0,0 +1 @@
+Add support for asynchronous uploads as defined by [MSC2246](https://github.com/matrix-org/matrix-spec-proposals/pull/2246). Contributed by @sumnerevans at @beeper.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 4200e70c83..7c4e742cd5 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -1753,6 +1753,19 @@ rc_third_party_invite:
   burst_count: 10
 ```
 ---
+### `rc_media_create`
+
+This option ratelimits creation of MXC URIs via the `/_matrix/media/v1/create`
+endpoint based on the account that's creating the media. Defaults to
+`per_second: 10`, `burst_count: 50`.
+
+Example configuration:
+```yaml
+rc_media_create:
+  per_second: 10
+  burst_count: 50
+```
+---
 ### `rc_federation`
 
 Defines limits on federation requests.
@@ -1814,6 +1827,27 @@ Example configuration:
 media_store_path: "DATADIR/media_store"
 ```
 ---
+### `max_pending_media_uploads`
+
+How many *pending media uploads* can a given user have? A pending media upload
+is a created MXC URI that (a) is not expired (the `unused_expires_at` timestamp
+has not passed) and (b) the media has not yet been uploaded for. Defaults to 5.
+
+Example configuration:
+```yaml
+max_pending_media_uploads: 5
+```
+---
+### `unused_expiration_time`
+
+How long to wait in milliseconds before expiring created media IDs. Defaults to
+"24h"
+
+Example configuration:
+```yaml
+unused_expiration_time: "1h"
+```
+---
 ### `media_storage_providers`
 
 Media storage providers allow media to be stored in different
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index fdb2955be8..fbd8b16ec3 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -83,6 +83,8 @@ class Codes(str, Enum):
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
     # USER_LOCKED = "M_USER_LOCKED"
     USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
+    NOT_YET_UPLOADED = "M_NOT_YET_UPLOADED"
+    CANNOT_OVERWRITE_MEDIA = "M_CANNOT_OVERWRITE_MEDIA"
 
     # Part of MSC3848
     # https://github.com/matrix-org/matrix-spec-proposals/pull/3848
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4efbaeac0d..b1fcaf71a3 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -204,3 +204,10 @@ class RatelimitConfig(Config):
             "rc_third_party_invite",
             defaults={"per_second": 0.0025, "burst_count": 5},
         )
+
+        # Ratelimit create media requests:
+        self.rc_media_create = RatelimitSettings.parse(
+            config,
+            "rc_media_create",
+            defaults={"per_second": 10, "burst_count": 50},
+        )
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index f6cfdd3e04..839c026d70 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -141,6 +141,12 @@ class ContentRepositoryConfig(Config):
             "prevent_media_downloads_from", []
         )
 
+        self.unused_expiration_time = self.parse_duration(
+            config.get("unused_expiration_time", "24h")
+        )
+
+        self.max_pending_media_uploads = config.get("max_pending_media_uploads", 5)
+
         self.media_store_path = self.ensure_directory(
             config.get("media_store_path", "media_store")
         )
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 860e5ddca2..9d88a711cf 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [
     "audio/x-flac",
 ]
 
+# Default timeout_ms for download and thumbnail requests
+DEFAULT_MAX_TIMEOUT_MS = 20_000
+
+# Maximum allowed timeout_ms for download and thumbnail requests
+MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
+
 
 def respond_404(request: SynapseRequest) -> None:
     assert request.path is not None
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1957426c6a..bf976b9e7c 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -27,13 +27,16 @@ import twisted.web.http
 from twisted.internet.defer import Deferred
 
 from synapse.api.errors import (
+    Codes,
     FederationDeniedError,
     HttpResponseException,
     NotFoundError,
     RequestSendFailed,
     SynapseError,
+    cs_error,
 )
 from synapse.config.repository import ThumbnailRequirement
+from synapse.http.server import respond_with_json
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.logging.opentracing import trace
@@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.databases.main.media_repository import RemoteMedia
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -80,6 +83,8 @@ class MediaRepository:
         self.store = hs.get_datastores().main
         self.max_upload_size = hs.config.media.max_upload_size
         self.max_image_pixels = hs.config.media.max_image_pixels
+        self.unused_expiration_time = hs.config.media.unused_expiration_time
+        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
 
         Thumbnailer.set_limits(self.max_image_pixels)
 
@@ -186,6 +191,117 @@ class MediaRepository:
             self.recently_accessed_locals.add(media_id)
 
     @trace
+    async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
+        """Create and store a media ID for a local user and return the MXC URI and its
+        expiration.
+
+        Args:
+            auth_user: The user_id of the uploader
+
+        Returns:
+            A tuple containing the MXC URI of the stored content and the timestamp at
+            which the MXC URI expires.
+        """
+        media_id = random_string(24)
+        now = self.clock.time_msec()
+        await self.store.store_local_media_id(
+            media_id=media_id,
+            time_now_ms=now,
+            user_id=auth_user,
+        )
+        return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time
+
+    @trace
+    async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
+        """Check if the user is over the limit for pending media uploads.
+
+        Args:
+            auth_user: The user_id of the uploader
+
+        Returns:
+            A tuple with a boolean and an integer indicating whether the user has too
+            many pending media uploads and the timestamp at which the first pending
+            media will expire, respectively.
+        """
+        pending, first_expiration_ts = await self.store.count_pending_media(
+            user_id=auth_user
+        )
+        return pending >= self.max_pending_media_uploads, first_expiration_ts
+
+    @trace
+    async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
+        """Verify that the media ID can be uploaded to by the given user. This
+        function checks that:
+
+        * the media ID exists
+        * the media ID does not already have content
+        * the user uploading is the same as the one who created the media ID
+        * the media ID has not expired
+
+        Args:
+            media_id: The media ID to verify
+            auth_user: The user_id of the uploader
+        """
+        media = await self.store.get_local_media(media_id)
+        if media is None:
+            raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
+
+        if media.user_id != auth_user.to_string():
+            raise SynapseError(
+                403,
+                "Only the creator of the media ID can upload to it",
+                errcode=Codes.FORBIDDEN,
+            )
+
+        if media.media_length is not None:
+            raise SynapseError(
+                409,
+                "Media ID already has content",
+                errcode=Codes.CANNOT_OVERWRITE_MEDIA,
+            )
+
+        expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
+        if media.created_ts < expired_time_ms:
+            raise NotFoundError("Media ID has expired")
+
+    @trace
+    async def update_content(
+        self,
+        media_id: str,
+        media_type: str,
+        upload_name: Optional[str],
+        content: IO,
+        content_length: int,
+        auth_user: UserID,
+    ) -> None:
+        """Update the content of the given media ID.
+
+        Args:
+            media_id: The media ID to replace.
+            media_type: The content type of the file.
+            upload_name: The name of the file, if provided.
+            content: A file like object that is the content to store
+            content_length: The length of the content
+            auth_user: The user_id of the uploader
+        """
+        file_info = FileInfo(server_name=None, file_id=media_id)
+        fname = await self.media_storage.store_file(content, file_info)
+        logger.info("Stored local media in file %r", fname)
+
+        await self.store.update_local_media(
+            media_id=media_id,
+            media_type=media_type,
+            upload_name=upload_name,
+            media_length=content_length,
+            user_id=auth_user,
+        )
+
+        try:
+            await self._generate_thumbnails(None, media_id, media_id, media_type)
+        except Exception as e:
+            logger.info("Failed to generate thumbnails: %s", e)
+
+    @trace
     async def create_content(
         self,
         media_type: str,
@@ -231,8 +347,74 @@ class MediaRepository:
 
         return MXCUri(self.server_name, media_id)
 
+    def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
+        respond_with_json(
+            request,
+            504,
+            cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED),
+            send_cors=True,
+        )
+
+    async def get_local_media_info(
+        self, request: SynapseRequest, media_id: str, max_timeout_ms: int
+    ) -> Optional[LocalMedia]:
+        """Gets the info dictionary for given local media ID. If the media has
+        not been uploaded yet, this function will wait up to ``max_timeout_ms``
+        milliseconds for the media to be uploaded.
+
+        Args:
+            request: The incoming request.
+            media_id: The media ID of the content. (This is the same as
+                the file_id for local content.)
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
+
+        Returns:
+            Either the info dictionary for the given local media ID or
+            ``None``. If ``None``, then no further processing is necessary as
+            this function will send the necessary JSON response.
+        """
+        wait_until = self.clock.time_msec() + max_timeout_ms
+        while True:
+            # Get the info for the media
+            media_info = await self.store.get_local_media(media_id)
+            if not media_info:
+                logger.info("Media %s is unknown", media_id)
+                respond_404(request)
+                return None
+
+            if media_info.quarantined_by:
+                logger.info("Media %s is quarantined", media_id)
+                respond_404(request)
+                return None
+
+            # The file has been uploaded, so stop looping
+            if media_info.media_length is not None:
+                return media_info
+
+            # Check if the media ID has expired and still hasn't been uploaded to.
+            now = self.clock.time_msec()
+            expired_time_ms = now - self.unused_expiration_time
+            if media_info.created_ts < expired_time_ms:
+                logger.info("Media %s has expired without being uploaded", media_id)
+                respond_404(request)
+                return None
+
+            if now >= wait_until:
+                break
+
+            await self.clock.sleep(0.5)
+
+        logger.info("Media %s has not yet been uploaded", media_id)
+        self.respond_not_yet_uploaded(request)
+        return None
+
     async def get_local_media(
-        self, request: SynapseRequest, media_id: str, name: Optional[str]
+        self,
+        request: SynapseRequest,
+        media_id: str,
+        name: Optional[str],
+        max_timeout_ms: int,
     ) -> None:
         """Responds to requests for local media, if exists, or returns 404.
 
@@ -242,13 +424,14 @@ class MediaRepository:
                 the file_id for local content.)
             name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             Resolves once a response has successfully been written to request
         """
-        media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info.quarantined_by:
-            respond_404(request)
+        media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
+        if not media_info:
             return
 
         self.mark_recently_accessed(None, media_id)
@@ -273,6 +456,7 @@ class MediaRepository:
         server_name: str,
         media_id: str,
         name: Optional[str],
+        max_timeout_ms: int,
     ) -> None:
         """Respond to requests for remote media.
 
@@ -282,6 +466,8 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the remote server).
             name: Optional name that, if specified, will be used as
                 the filename in the Content-Disposition header of the response.
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             Resolves once a response has successfully been written to request
@@ -307,11 +493,11 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id
+                server_name, media_id, max_timeout_ms
             )
 
         # We deliberately stream the file outside the lock
-        if responder:
+        if responder and media_info:
             upload_name = name if name else media_info.upload_name
             await respond_with_responder(
                 request,
@@ -324,7 +510,7 @@ class MediaRepository:
             respond_404(request)
 
     async def get_remote_media_info(
-        self, server_name: str, media_id: str
+        self, server_name: str, media_id: str, max_timeout_ms: int
     ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
@@ -332,6 +518,8 @@ class MediaRepository:
         Args:
             server_name: Remote server_name where the media originated.
             media_id: The media ID of the content (as defined by the remote server).
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             The media info of the file
@@ -347,7 +535,7 @@ class MediaRepository:
         key = (server_name, media_id)
         async with self.remote_media_linearizer.queue(key):
             responder, media_info = await self._get_remote_media_impl(
-                server_name, media_id
+                server_name, media_id, max_timeout_ms
             )
 
         # Ensure we actually use the responder so that it releases resources
@@ -358,7 +546,7 @@ class MediaRepository:
         return media_info
 
     async def _get_remote_media_impl(
-        self, server_name: str, media_id: str
+        self, server_name: str, media_id: str, max_timeout_ms: int
     ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
@@ -367,6 +555,8 @@ class MediaRepository:
             server_name: Remote server_name where the media originated.
             media_id: The media ID of the content (as defined by the
                 remote server).
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             A tuple of responder and the media info of the file.
@@ -399,8 +589,7 @@ class MediaRepository:
 
         try:
             media_info = await self._download_remote_file(
-                server_name,
-                media_id,
+                server_name, media_id, max_timeout_ms
             )
         except SynapseError:
             raise
@@ -433,6 +622,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
+        max_timeout_ms: int,
     ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
@@ -442,7 +632,8 @@ class MediaRepository:
             media_id: The media ID of the content (as defined by the
                 remote server). This is different than the file_id, which is
                 locally generated.
-            file_id: Local file ID
+            max_timeout_ms: the maximum number of milliseconds to wait for the
+                media to be uploaded.
 
         Returns:
             The media info of the file.
@@ -466,7 +657,8 @@ class MediaRepository:
                         # tell the remote server to 404 if it doesn't
                         # recognise the server_name, to make sure we don't
                         # end up with a routing loop.
-                        "allow_remote": "false"
+                        "allow_remote": "false",
+                        "timeout_ms": str(max_timeout_ms),
                     },
                 )
             except RequestSendFailed as e:
diff --git a/synapse/rest/media/create_resource.py b/synapse/rest/media/create_resource.py
new file mode 100644
index 0000000000..994afdf13c
--- /dev/null
+++ b/synapse/rest/media/create_resource.py
@@ -0,0 +1,83 @@
+# Copyright 2023 Beeper Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+from typing import TYPE_CHECKING
+
+from synapse.api.errors import LimitExceededError
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.http.server import respond_with_json
+from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.media.media_repository import MediaRepository
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class CreateResource(RestServlet):
+    PATTERNS = [re.compile("/_matrix/media/v1/create")]
+
+    def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
+        super().__init__()
+
+        self.media_repo = media_repo
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
+
+        # A rate limiter for creating new media IDs.
+        self._create_media_rate_limiter = Ratelimiter(
+            store=hs.get_datastores().main,
+            clock=self.clock,
+            cfg=hs.config.ratelimiting.rc_media_create,
+        )
+
+    async def on_POST(self, request: SynapseRequest) -> None:
+        requester = await self.auth.get_user_by_req(request)
+
+        # If the create media requests for the user are over the limit, drop them.
+        await self._create_media_rate_limiter.ratelimit(requester)
+
+        (
+            reached_pending_limit,
+            first_expiration_ts,
+        ) = await self.media_repo.reached_pending_media_limit(requester.user)
+        if reached_pending_limit:
+            raise LimitExceededError(
+                limiter_name="max_pending_media_uploads",
+                retry_after_ms=first_expiration_ts - self.clock.time_msec(),
+            )
+
+        content_uri, unused_expires_at = await self.media_repo.create_media_id(
+            requester.user
+        )
+
+        logger.info(
+            "Created Media URI %r that if unused will expire at %d",
+            content_uri,
+            unused_expires_at,
+        )
+        respond_with_json(
+            request,
+            200,
+            {
+                "content_uri": content_uri,
+                "unused_expires_at": unused_expires_at,
+            },
+            send_cors=True,
+        )
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 65b9ff52fa..60cd87548c 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -17,9 +17,13 @@ import re
 from typing import TYPE_CHECKING, Optional
 
 from synapse.http.server import set_corp_headers, set_cors_headers
-from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
 from synapse.http.site import SynapseRequest
-from synapse.media._base import respond_404
+from synapse.media._base import (
+    DEFAULT_MAX_TIMEOUT_MS,
+    MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
+    respond_404,
+)
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -65,12 +69,16 @@ class DownloadResource(RestServlet):
         )
         # Limited non-standard form of CSP for IE11
         request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
-        request.setHeader(
-            b"Referrer-Policy",
-            b"no-referrer",
+        request.setHeader(b"Referrer-Policy", b"no-referrer")
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
         )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
         if self._is_mine_server_name(server_name):
-            await self.media_repo.get_local_media(request, media_id, file_name)
+            await self.media_repo.get_local_media(
+                request, media_id, file_name, max_timeout_ms
+            )
         else:
             allow_remote = parse_boolean(request, "allow_remote", default=True)
             if not allow_remote:
@@ -83,5 +91,5 @@ class DownloadResource(RestServlet):
                 return
 
             await self.media_repo.get_remote_media(
-                request, server_name, media_id, file_name
+                request, server_name, media_id, file_name, max_timeout_ms
             )
diff --git a/synapse/rest/media/media_repository_resource.py b/synapse/rest/media/media_repository_resource.py
index 2089bb1029..ca65116b84 100644
--- a/synapse/rest/media/media_repository_resource.py
+++ b/synapse/rest/media/media_repository_resource.py
@@ -18,10 +18,11 @@ from synapse.config._base import ConfigError
 from synapse.http.server import HttpServer, JsonResource
 
 from .config_resource import MediaConfigResource
+from .create_resource import CreateResource
 from .download_resource import DownloadResource
 from .preview_url_resource import PreviewUrlResource
 from .thumbnail_resource import ThumbnailResource
-from .upload_resource import UploadResource
+from .upload_resource import AsyncUploadServlet, UploadServlet
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -91,8 +92,9 @@ class MediaRepositoryResource(JsonResource):
 
         # Note that many of these should not exist as v1 endpoints, but empirically
         # a lot of traffic still goes to them.
-
-        UploadResource(hs, media_repo).register(http_server)
+        CreateResource(hs, media_repo).register(http_server)
+        UploadServlet(hs, media_repo).register(http_server)
+        AsyncUploadServlet(hs, media_repo).register(http_server)
         DownloadResource(hs, media_repo).register(http_server)
         ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
             http_server
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index efda8b4ab4..681f2a5a27 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -23,6 +23,8 @@ from synapse.http.server import respond_with_json, set_corp_headers, set_cors_he
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.media._base import (
+    DEFAULT_MAX_TIMEOUT_MS,
+    MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
     FileInfo,
     ThumbnailInfo,
     respond_404,
@@ -75,15 +77,19 @@ class ThumbnailResource(RestServlet):
         method = parse_string(request, "method", "scale")
         # TODO Parse the Accept header to get an prioritised list of thumbnail types.
         m_type = "image/png"
+        max_timeout_ms = parse_integer(
+            request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+        )
+        max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
 
         if self._is_mine_server_name(server_name):
             if self.dynamic_thumbnails:
                 await self._select_or_generate_local_thumbnail(
-                    request, media_id, width, height, method, m_type
+                    request, media_id, width, height, method, m_type, max_timeout_ms
                 )
             else:
                 await self._respond_local_thumbnail(
-                    request, media_id, width, height, method, m_type
+                    request, media_id, width, height, method, m_type, max_timeout_ms
                 )
             self.media_repo.mark_recently_accessed(None, media_id)
         else:
@@ -95,14 +101,21 @@ class ThumbnailResource(RestServlet):
                 respond_404(request)
                 return
 
-            if self.dynamic_thumbnails:
-                await self._select_or_generate_remote_thumbnail(
-                    request, server_name, media_id, width, height, method, m_type
-                )
-            else:
-                await self._respond_remote_thumbnail(
-                    request, server_name, media_id, width, height, method, m_type
-                )
+            remote_resp_function = (
+                self._select_or_generate_remote_thumbnail
+                if self.dynamic_thumbnails
+                else self._respond_remote_thumbnail
+            )
+            await remote_resp_function(
+                request,
+                server_name,
+                media_id,
+                width,
+                height,
+                method,
+                m_type,
+                max_timeout_ms,
+            )
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
     async def _respond_local_thumbnail(
@@ -113,15 +126,12 @@ class ThumbnailResource(RestServlet):
         height: int,
         method: str,
         m_type: str,
+        max_timeout_ms: int,
     ) -> None:
-        media_info = await self.store.get_local_media(media_id)
-
+        media_info = await self.media_repo.get_local_media_info(
+            request, media_id, max_timeout_ms
+        )
         if not media_info:
-            respond_404(request)
-            return
-        if media_info.quarantined_by:
-            logger.info("Media is quarantined")
-            respond_404(request)
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@@ -146,15 +156,13 @@ class ThumbnailResource(RestServlet):
         desired_height: int,
         desired_method: str,
         desired_type: str,
+        max_timeout_ms: int,
     ) -> None:
-        media_info = await self.store.get_local_media(media_id)
+        media_info = await self.media_repo.get_local_media_info(
+            request, media_id, max_timeout_ms
+        )
 
         if not media_info:
-            respond_404(request)
-            return
-        if media_info.quarantined_by:
-            logger.info("Media is quarantined")
-            respond_404(request)
             return
 
         thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
@@ -206,8 +214,14 @@ class ThumbnailResource(RestServlet):
         desired_height: int,
         desired_method: str,
         desired_type: str,
+        max_timeout_ms: int,
     ) -> None:
-        media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
+        media_info = await self.media_repo.get_remote_media_info(
+            server_name, media_id, max_timeout_ms
+        )
+        if not media_info:
+            respond_404(request)
+            return
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
@@ -263,11 +277,16 @@ class ThumbnailResource(RestServlet):
         height: int,
         method: str,
         m_type: str,
+        max_timeout_ms: int,
     ) -> None:
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
-        media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
+        media_info = await self.media_repo.get_remote_media_info(
+            server_name, media_id, max_timeout_ms
+        )
+        if not media_info:
+            return
 
         thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
diff --git a/synapse/rest/media/upload_resource.py b/synapse/rest/media/upload_resource.py
index 949326d85d..62d3e228a8 100644
--- a/synapse/rest/media/upload_resource.py
+++ b/synapse/rest/media/upload_resource.py
@@ -15,7 +15,7 @@
 
 import logging
 import re
-from typing import IO, TYPE_CHECKING, Dict, List, Optional
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import respond_with_json
@@ -29,23 +29,24 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# The name of the lock to use when uploading media.
+_UPLOAD_MEDIA_LOCK_NAME = "upload_media"
 
-class UploadResource(RestServlet):
-    PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
 
+class BaseUploadServlet(RestServlet):
     def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
 
         self.media_repo = media_repo
         self.filepaths = media_repo.filepaths
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
         self.auth = hs.get_auth()
         self.max_upload_size = hs.config.media.max_upload_size
-        self.clock = hs.get_clock()
 
-    async def on_POST(self, request: SynapseRequest) -> None:
-        requester = await self.auth.get_user_by_req(request)
+    def _get_file_metadata(
+        self, request: SynapseRequest
+    ) -> Tuple[int, Optional[str], str]:
         raw_content_length = request.getHeader("Content-Length")
         if raw_content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)
@@ -88,6 +89,16 @@ class UploadResource(RestServlet):
         #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
+        return content_length, upload_name, media_type
+
+
+class UploadServlet(BaseUploadServlet):
+    PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload$")]
+
+    async def on_POST(self, request: SynapseRequest) -> None:
+        requester = await self.auth.get_user_by_req(request)
+        content_length, upload_name, media_type = self._get_file_metadata(request)
+
         try:
             content: IO = request.content  # type: ignore
             content_uri = await self.media_repo.create_content(
@@ -103,3 +114,53 @@ class UploadResource(RestServlet):
         respond_with_json(
             request, 200, {"content_uri": str(content_uri)}, send_cors=True
         )
+
+
+class AsyncUploadServlet(BaseUploadServlet):
+    PATTERNS = [
+        re.compile(
+            "/_matrix/media/v3/upload/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
+        )
+    ]
+
+    async def on_PUT(
+        self, request: SynapseRequest, server_name: str, media_id: str
+    ) -> None:
+        requester = await self.auth.get_user_by_req(request)
+
+        if server_name != self.server_name:
+            raise SynapseError(
+                404,
+                "Non-local server name specified",
+                errcode=Codes.NOT_FOUND,
+            )
+
+        lock = await self.store.try_acquire_lock(_UPLOAD_MEDIA_LOCK_NAME, media_id)
+        if not lock:
+            raise SynapseError(
+                409,
+                "Media ID cannot be overwritten",
+                errcode=Codes.CANNOT_OVERWRITE_MEDIA,
+            )
+
+        async with lock:
+            await self.media_repo.verify_can_upload(media_id, requester.user)
+            content_length, upload_name, media_type = self._get_file_metadata(request)
+
+            try:
+                content: IO = request.content  # type: ignore
+                await self.media_repo.update_content(
+                    media_id,
+                    media_type,
+                    upload_name,
+                    content,
+                    content_length,
+                    requester.user,
+                )
+            except SpamMediaException:
+                # For uploading of media we want to respond with a 400, instead of
+                # the default 404, as that would just be confusing.
+                raise SynapseError(400, "Bad content")
+
+            logger.info("Uploaded content for media ID %r", media_id)
+            respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 3f80a64dc5..149135b8b5 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -49,13 +49,14 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
 class LocalMedia:
     media_id: str
     media_type: str
-    media_length: int
+    media_length: Optional[int]
     upload_name: str
     created_ts: int
     url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
+    user_id: Optional[str]
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -149,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             self._drop_media_index_without_method,
         )
 
+        if hs.config.media.can_load_media_repo:
+            self.unused_expiration_time: Optional[
+                int
+            ] = hs.config.media.unused_expiration_time
+        else:
+            self.unused_expiration_time = None
+
     async def _drop_media_index_without_method(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -202,6 +210,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "url_cache",
                 "last_access_ts",
                 "safe_from_quarantine",
+                "user_id",
             ),
             allow_none=True,
             desc="get_local_media",
@@ -218,6 +227,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             url_cache=row[5],
             last_access_ts=row[6],
             safe_from_quarantine=row[7],
+            user_id=row[8],
         )
 
     async def get_local_media_by_user_paginate(
@@ -272,7 +282,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     url_cache,
                     last_access_ts,
                     quarantined_by,
-                    safe_from_quarantine
+                    safe_from_quarantine,
+                    user_id
                 FROM local_media_repository
                 WHERE user_id = ?
                 ORDER BY {order_by_column} {order}, media_id ASC
@@ -295,6 +306,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     last_access_ts=row[6],
                     quarantined_by=row[7],
                     safe_from_quarantine=bool(row[8]),
+                    user_id=row[9],
                 )
                 for row in txn
             ]
@@ -392,6 +404,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     @trace
+    async def store_local_media_id(
+        self,
+        media_id: str,
+        time_now_ms: int,
+        user_id: UserID,
+    ) -> None:
+        await self.db_pool.simple_insert(
+            "local_media_repository",
+            {
+                "media_id": media_id,
+                "created_ts": time_now_ms,
+                "user_id": user_id.to_string(),
+            },
+            desc="store_local_media_id",
+        )
+
+    @trace
     async def store_local_media(
         self,
         media_id: str,
@@ -416,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
+    async def update_local_media(
+        self,
+        media_id: str,
+        media_type: str,
+        upload_name: Optional[str],
+        media_length: int,
+        user_id: UserID,
+        url_cache: Optional[str] = None,
+    ) -> None:
+        await self.db_pool.simple_update_one(
+            "local_media_repository",
+            keyvalues={
+                "user_id": user_id.to_string(),
+                "media_id": media_id,
+            },
+            updatevalues={
+                "media_type": media_type,
+                "upload_name": upload_name,
+                "media_length": media_length,
+                "url_cache": url_cache,
+            },
+            desc="update_local_media",
+        )
+
     async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None:
         """Mark a local media as safe or unsafe from quarantining."""
         await self.db_pool.simple_update_one(
@@ -425,6 +478,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
+    async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]:
+        """Count the number of pending media for a user.
+
+        Returns:
+            A tuple of two integers: the total pending media requests and the earliest
+            expiration timestamp.
+        """
+
+        def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]:
+            sql = """
+            SELECT COUNT(*), MIN(created_ts)
+            FROM local_media_repository
+            WHERE user_id = ?
+                AND created_ts > ?
+                AND media_length IS NULL
+            """
+            assert self.unused_expiration_time is not None
+            txn.execute(
+                sql,
+                (
+                    user_id.to_string(),
+                    self._clock.time_msec() - self.unused_expiration_time,
+                ),
+            )
+            row = txn.fetchone()
+            if not row:
+                return 0, 0
+            return row[0], (row[1] + self.unused_expiration_time if row[1] else 0)
+
+        return await self.db_pool.runInteraction(
+            "get_pending_media", get_pending_media_txn
+        )
+
     async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py
index a8e7a76b29..f262304c3d 100644
--- a/tests/media/test_media_storage.py
+++ b/tests/media/test_media_storage.py
@@ -318,7 +318,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         self.assertEqual(
             self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
         )
-        self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
+        self.assertEqual(
+            self.fetches[0][3], {"allow_remote": "false", "timeout_ms": "20000"}
+        )
 
         headers = {
             b"Content-Length": [b"%d" % (len(self.test_image.data))],