From ff716b483b07b21de72d999250fdf9397003a914 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Nov 2023 11:00:30 -0500 Subject: Return attrs for more media repo APIs. (#16611) --- synapse/media/media_repository.py | 70 ++++++++++++++++++++++----------------- synapse/media/url_previewer.py | 11 +++--- 2 files changed, 45 insertions(+), 36 deletions(-) (limited to 'synapse/media') diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 72b0f1c5de..1957426c6a 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -19,6 +19,7 @@ import shutil from io import BytesIO from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple +import attr from matrix_common.types.mxc_uri import MXCUri import twisted.internet.error @@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.media_repository import RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination @@ -245,18 +247,18 @@ class MediaRepository: Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) - if not media_info or media_info["quarantined_by"]: + if not media_info or media_info.quarantined_by: respond_404(request) return self.mark_recently_accessed(None, media_id) - media_type = media_info["media_type"] + media_type = media_info.media_type if not media_type: media_type = "application/octet-stream" - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] - url_cache = media_info["url_cache"] + media_length = media_info.media_length + upload_name = name if name else media_info.upload_name + url_cache = media_info.url_cache file_info = FileInfo(None, media_id, url_cache=bool(url_cache)) @@ -310,16 +312,20 @@ class MediaRepository: # We deliberately stream the file outside the lock if responder: - media_type = media_info["media_type"] - media_length = media_info["media_length"] - upload_name = name if name else media_info["upload_name"] + upload_name = name if name else media_info.upload_name await respond_with_responder( - request, responder, media_type, media_length, upload_name + request, + responder, + media_info.media_type, + media_info.media_length, + upload_name, ) else: respond_404(request) - async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: + async def get_remote_media_info( + self, server_name: str, media_id: str + ) -> RemoteMedia: """Gets the media info associated with the remote file, downloading if necessary. @@ -353,7 +359,7 @@ class MediaRepository: async def _get_remote_media_impl( self, server_name: str, media_id: str - ) -> Tuple[Optional[Responder], dict]: + ) -> Tuple[Optional[Responder], RemoteMedia]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -373,15 +379,17 @@ class MediaRepository: # If we have an entry in the DB, try and look for it if media_info: - file_id = media_info["filesystem_id"] + file_id = media_info.filesystem_id file_info = FileInfo(server_name, file_id) - if media_info["quarantined_by"]: + if media_info.quarantined_by: logger.info("Media is quarantined") raise NotFoundError() - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + if not media_info.media_type: + media_info = attr.evolve( + media_info, media_type="application/octet-stream" + ) responder = await self.media_storage.fetch_media(file_info) if responder: @@ -403,9 +411,9 @@ class MediaRepository: if not media_info: raise e - file_id = media_info["filesystem_id"] - if not media_info["media_type"]: - media_info["media_type"] = "application/octet-stream" + file_id = media_info.filesystem_id + if not media_info.media_type: + media_info = attr.evolve(media_info, media_type="application/octet-stream") file_info = FileInfo(server_name, file_id) # We generate thumbnails even if another process downloaded the media @@ -415,7 +423,7 @@ class MediaRepository: # otherwise they'll request thumbnails and get a 404 if they're not # ready yet. await self._generate_thumbnails( - server_name, media_id, file_id, media_info["media_type"] + server_name, media_id, file_id, media_info.media_type ) responder = await self.media_storage.fetch_media(file_info) @@ -425,7 +433,7 @@ class MediaRepository: self, server_name: str, media_id: str, - ) -> dict: + ) -> RemoteMedia: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -518,7 +526,7 @@ class MediaRepository: origin=server_name, media_id=media_id, media_type=media_type, - time_now_ms=self.clock.time_msec(), + time_now_ms=time_now_ms, upload_name=upload_name, media_length=length, filesystem_id=file_id, @@ -526,15 +534,17 @@ class MediaRepository: logger.info("Stored remote media in file %r", fname) - media_info = { - "media_type": media_type, - "media_length": length, - "upload_name": upload_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - } - - return media_info + return RemoteMedia( + media_origin=server_name, + media_id=media_id, + media_type=media_type, + media_length=length, + upload_name=upload_name, + created_ts=time_now_ms, + filesystem_id=file_id, + last_access_ts=time_now_ms, + quarantined_by=None, + ) def _get_thumbnail_requirements( self, media_type: str diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 9b5a3dd5f4..44aac21de6 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -240,15 +240,14 @@ class UrlPreviewer: cache_result = await self.store.get_url_cache(url, ts) if ( cache_result - and cache_result["expires_ts"] > ts - and cache_result["response_code"] / 100 == 2 + and cache_result.expires_ts > ts + and cache_result.response_code // 100 == 2 ): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. - og = cache_result["og"] - if isinstance(og, str): - og = og.encode("utf8") - return og + if isinstance(cache_result.og, str): + return cache_result.og.encode("utf8") + return cache_result.og # If this URL can be accessed via an allowed oEmbed, use that instead. url_to_download = url -- cgit 1.4.1