diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 860e5ddca2..9d88a711cf 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -83,6 +83,12 @@ INLINE_CONTENT_TYPES = [
"audio/x-flac",
]
+# Default timeout_ms for download and thumbnail requests
+DEFAULT_MAX_TIMEOUT_MS = 20_000
+
+# Maximum allowed timeout_ms for download and thumbnail requests
+MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
+
def respond_404(request: SynapseRequest) -> None:
assert request.path is not None
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 1957426c6a..bf976b9e7c 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -27,13 +27,16 @@ import twisted.web.http
from twisted.internet.defer import Deferred
from synapse.api.errors import (
+ Codes,
FederationDeniedError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
SynapseError,
+ cs_error,
)
from synapse.config.repository import ThumbnailRequirement
+from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.logging.opentracing import trace
@@ -51,7 +54,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.databases.main.media_repository import RemoteMedia
+from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -80,6 +83,8 @@ class MediaRepository:
self.store = hs.get_datastores().main
self.max_upload_size = hs.config.media.max_upload_size
self.max_image_pixels = hs.config.media.max_image_pixels
+ self.unused_expiration_time = hs.config.media.unused_expiration_time
+ self.max_pending_media_uploads = hs.config.media.max_pending_media_uploads
Thumbnailer.set_limits(self.max_image_pixels)
@@ -186,6 +191,117 @@ class MediaRepository:
self.recently_accessed_locals.add(media_id)
@trace
+ async def create_media_id(self, auth_user: UserID) -> Tuple[str, int]:
+ """Create and store a media ID for a local user and return the MXC URI and its
+ expiration.
+
+ Args:
+ auth_user: The user_id of the uploader
+
+ Returns:
+ A tuple containing the MXC URI of the stored content and the timestamp at
+ which the MXC URI expires.
+ """
+ media_id = random_string(24)
+ now = self.clock.time_msec()
+ await self.store.store_local_media_id(
+ media_id=media_id,
+ time_now_ms=now,
+ user_id=auth_user,
+ )
+ return f"mxc://{self.server_name}/{media_id}", now + self.unused_expiration_time
+
+ @trace
+ async def reached_pending_media_limit(self, auth_user: UserID) -> Tuple[bool, int]:
+ """Check if the user is over the limit for pending media uploads.
+
+ Args:
+ auth_user: The user_id of the uploader
+
+ Returns:
+ A tuple with a boolean and an integer indicating whether the user has too
+ many pending media uploads and the timestamp at which the first pending
+ media will expire, respectively.
+ """
+ pending, first_expiration_ts = await self.store.count_pending_media(
+ user_id=auth_user
+ )
+ return pending >= self.max_pending_media_uploads, first_expiration_ts
+
+ @trace
+ async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
+ """Verify that the media ID can be uploaded to by the given user. This
+ function checks that:
+
+ * the media ID exists
+ * the media ID does not already have content
+ * the user uploading is the same as the one who created the media ID
+ * the media ID has not expired
+
+ Args:
+ media_id: The media ID to verify
+ auth_user: The user_id of the uploader
+ """
+ media = await self.store.get_local_media(media_id)
+ if media is None:
+ raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
+
+ if media.user_id != auth_user.to_string():
+ raise SynapseError(
+ 403,
+ "Only the creator of the media ID can upload to it",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ if media.media_length is not None:
+ raise SynapseError(
+ 409,
+ "Media ID already has content",
+ errcode=Codes.CANNOT_OVERWRITE_MEDIA,
+ )
+
+ expired_time_ms = self.clock.time_msec() - self.unused_expiration_time
+ if media.created_ts < expired_time_ms:
+ raise NotFoundError("Media ID has expired")
+
+ @trace
+ async def update_content(
+ self,
+ media_id: str,
+ media_type: str,
+ upload_name: Optional[str],
+ content: IO,
+ content_length: int,
+ auth_user: UserID,
+ ) -> None:
+ """Update the content of the given media ID.
+
+ Args:
+ media_id: The media ID to replace.
+ media_type: The content type of the file.
+ upload_name: The name of the file, if provided.
+ content: A file like object that is the content to store
+ content_length: The length of the content
+ auth_user: The user_id of the uploader
+ """
+ file_info = FileInfo(server_name=None, file_id=media_id)
+ fname = await self.media_storage.store_file(content, file_info)
+ logger.info("Stored local media in file %r", fname)
+
+ await self.store.update_local_media(
+ media_id=media_id,
+ media_type=media_type,
+ upload_name=upload_name,
+ media_length=content_length,
+ user_id=auth_user,
+ )
+
+ try:
+ await self._generate_thumbnails(None, media_id, media_id, media_type)
+ except Exception as e:
+ logger.info("Failed to generate thumbnails: %s", e)
+
+ @trace
async def create_content(
self,
media_type: str,
@@ -231,8 +347,74 @@ class MediaRepository:
return MXCUri(self.server_name, media_id)
+ def respond_not_yet_uploaded(self, request: SynapseRequest) -> None:
+ respond_with_json(
+ request,
+ 504,
+ cs_error("Media has not been uploaded yet", code=Codes.NOT_YET_UPLOADED),
+ send_cors=True,
+ )
+
+ async def get_local_media_info(
+ self, request: SynapseRequest, media_id: str, max_timeout_ms: int
+ ) -> Optional[LocalMedia]:
+ """Gets the info dictionary for given local media ID. If the media has
+ not been uploaded yet, this function will wait up to ``max_timeout_ms``
+ milliseconds for the media to be uploaded.
+
+ Args:
+ request: The incoming request.
+ media_id: The media ID of the content. (This is the same as
+ the file_id for local content.)
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
+
+ Returns:
+ Either the info dictionary for the given local media ID or
+ ``None``. If ``None``, then no further processing is necessary as
+ this function will send the necessary JSON response.
+ """
+ wait_until = self.clock.time_msec() + max_timeout_ms
+ while True:
+ # Get the info for the media
+ media_info = await self.store.get_local_media(media_id)
+ if not media_info:
+ logger.info("Media %s is unknown", media_id)
+ respond_404(request)
+ return None
+
+ if media_info.quarantined_by:
+ logger.info("Media %s is quarantined", media_id)
+ respond_404(request)
+ return None
+
+ # The file has been uploaded, so stop looping
+ if media_info.media_length is not None:
+ return media_info
+
+ # Check if the media ID has expired and still hasn't been uploaded to.
+ now = self.clock.time_msec()
+ expired_time_ms = now - self.unused_expiration_time
+ if media_info.created_ts < expired_time_ms:
+ logger.info("Media %s has expired without being uploaded", media_id)
+ respond_404(request)
+ return None
+
+ if now >= wait_until:
+ break
+
+ await self.clock.sleep(0.5)
+
+ logger.info("Media %s has not yet been uploaded", media_id)
+ self.respond_not_yet_uploaded(request)
+ return None
+
async def get_local_media(
- self, request: SynapseRequest, media_id: str, name: Optional[str]
+ self,
+ request: SynapseRequest,
+ media_id: str,
+ name: Optional[str],
+ max_timeout_ms: int,
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -242,13 +424,14 @@ class MediaRepository:
the file_id for local content.)
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
Returns:
Resolves once a response has successfully been written to request
"""
- media_info = await self.store.get_local_media(media_id)
- if not media_info or media_info.quarantined_by:
- respond_404(request)
+ media_info = await self.get_local_media_info(request, media_id, max_timeout_ms)
+ if not media_info:
return
self.mark_recently_accessed(None, media_id)
@@ -273,6 +456,7 @@ class MediaRepository:
server_name: str,
media_id: str,
name: Optional[str],
+ max_timeout_ms: int,
) -> None:
"""Respond to requests for remote media.
@@ -282,6 +466,8 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
Returns:
Resolves once a response has successfully been written to request
@@ -307,11 +493,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id
+ server_name, media_id, max_timeout_ms
)
# We deliberately stream the file outside the lock
- if responder:
+ if responder and media_info:
upload_name = name if name else media_info.upload_name
await respond_with_responder(
request,
@@ -324,7 +510,7 @@ class MediaRepository:
respond_404(request)
async def get_remote_media_info(
- self, server_name: str, media_id: str
+ self, server_name: str, media_id: str, max_timeout_ms: int
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -332,6 +518,8 @@ class MediaRepository:
Args:
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
Returns:
The media info of the file
@@ -347,7 +535,7 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id
+ server_name, media_id, max_timeout_ms
)
# Ensure we actually use the responder so that it releases resources
@@ -358,7 +546,7 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
- self, server_name: str, media_id: str
+ self, server_name: str, media_id: str, max_timeout_ms: int
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -367,6 +555,8 @@ class MediaRepository:
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the
remote server).
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
Returns:
A tuple of responder and the media info of the file.
@@ -399,8 +589,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
- server_name,
- media_id,
+ server_name, media_id, max_timeout_ms
)
except SynapseError:
raise
@@ -433,6 +622,7 @@ class MediaRepository:
self,
server_name: str,
media_id: str,
+ max_timeout_ms: int,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -442,7 +632,8 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
- file_id: Local file ID
+ max_timeout_ms: the maximum number of milliseconds to wait for the
+ media to be uploaded.
Returns:
The media info of the file.
@@ -466,7 +657,8 @@ class MediaRepository:
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
- "allow_remote": "false"
+ "allow_remote": "false",
+ "timeout_ms": str(max_timeout_ms),
},
)
except RequestSendFailed as e:
|