diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..1957426c6a 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+import attr
from matrix_common.types.mxc_uri import MXCUri
import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info or media_info.quarantined_by:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
- media_type = media_info["media_type"]
+ media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- url_cache = media_info["url_cache"]
+ media_length = media_info.media_length
+ upload_name = name if name else media_info.upload_name
+ url_cache = media_info.url_cache
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
@@ -310,16 +312,20 @@ class MediaRepository:
# We deliberately stream the file outside the lock
if responder:
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
+ upload_name = name if name else media_info.upload_name
await respond_with_responder(
- request, responder, media_type, media_length, upload_name
+ request,
+ responder,
+ media_info.media_type,
+ media_info.media_length,
+ upload_name,
)
else:
respond_404(request)
- async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+ async def get_remote_media_info(
+ self, server_name: str, media_id: str
+ ) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -353,7 +359,7 @@ class MediaRepository:
async def _get_remote_media_impl(
self, server_name: str, media_id: str
- ) -> Tuple[Optional[Responder], dict]:
+ ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -373,15 +379,17 @@ class MediaRepository:
# If we have an entry in the DB, try and look for it
if media_info:
- file_id = media_info["filesystem_id"]
+ file_id = media_info.filesystem_id
file_info = FileInfo(server_name, file_id)
- if media_info["quarantined_by"]:
+ if media_info.quarantined_by:
logger.info("Media is quarantined")
raise NotFoundError()
- if not media_info["media_type"]:
- media_info["media_type"] = "application/octet-stream"
+ if not media_info.media_type:
+ media_info = attr.evolve(
+ media_info, media_type="application/octet-stream"
+ )
responder = await self.media_storage.fetch_media(file_info)
if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
if not media_info:
raise e
- file_id = media_info["filesystem_id"]
- if not media_info["media_type"]:
- media_info["media_type"] = "application/octet-stream"
+ file_id = media_info.filesystem_id
+ if not media_info.media_type:
+ media_info = attr.evolve(media_info, media_type="application/octet-stream")
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
# otherwise they'll request thumbnails and get a 404 if they're not
# ready yet.
await self._generate_thumbnails(
- server_name, media_id, file_id, media_info["media_type"]
+ server_name, media_id, file_id, media_info.media_type
)
responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
self,
server_name: str,
media_id: str,
- ) -> dict:
+ ) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -518,7 +526,7 @@ class MediaRepository:
origin=server_name,
media_id=media_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
logger.info("Stored remote media in file %r", fname)
- media_info = {
- "media_type": media_type,
- "media_length": length,
- "upload_name": upload_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- }
-
- return media_info
+ return RemoteMedia(
+ media_origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ media_length=length,
+ upload_name=upload_name,
+ created_ts=time_now_ms,
+ filesystem_id=file_id,
+ last_access_ts=time_now_ms,
+ quarantined_by=None,
+ )
def _get_thumbnail_requirements(
self, media_type: str
|