summary refs log tree commit diff
path: root/synapse/media/media_repository.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/media/media_repository.py')
-rw-r--r--synapse/media/media_repository.py70
1 files changed, 40 insertions, 30 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 72b0f1c5de..1957426c6a 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -19,6 +19,7 @@ import shutil
 from io import BytesIO
 from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
+import attr
 from matrix_common.types.mxc_uri import MXCUri
 
 import twisted.internet.error
@@ -50,6 +51,7 @@ from synapse.media.storage_provider import StorageProviderWrapper
 from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
 from synapse.media.url_previewer import UrlPreviewer
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.media_repository import RemoteMedia
 from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -245,18 +247,18 @@ class MediaRepository:
             Resolves once a response has successfully been written to request
         """
         media_info = await self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info or media_info.quarantined_by:
             respond_404(request)
             return
 
         self.mark_recently_accessed(None, media_id)
 
-        media_type = media_info["media_type"]
+        media_type = media_info.media_type
         if not media_type:
             media_type = "application/octet-stream"
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        url_cache = media_info["url_cache"]
+        media_length = media_info.media_length
+        upload_name = name if name else media_info.upload_name
+        url_cache = media_info.url_cache
 
         file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
 
@@ -310,16 +312,20 @@ class MediaRepository:
 
         # We deliberately stream the file outside the lock
         if responder:
-            media_type = media_info["media_type"]
-            media_length = media_info["media_length"]
-            upload_name = name if name else media_info["upload_name"]
+            upload_name = name if name else media_info.upload_name
             await respond_with_responder(
-                request, responder, media_type, media_length, upload_name
+                request,
+                responder,
+                media_info.media_type,
+                media_info.media_length,
+                upload_name,
             )
         else:
             respond_404(request)
 
-    async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
+    async def get_remote_media_info(
+        self, server_name: str, media_id: str
+    ) -> RemoteMedia:
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
@@ -353,7 +359,7 @@ class MediaRepository:
 
     async def _get_remote_media_impl(
         self, server_name: str, media_id: str
-    ) -> Tuple[Optional[Responder], dict]:
+    ) -> Tuple[Optional[Responder], RemoteMedia]:
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -373,15 +379,17 @@ class MediaRepository:
 
         # If we have an entry in the DB, try and look for it
         if media_info:
-            file_id = media_info["filesystem_id"]
+            file_id = media_info.filesystem_id
             file_info = FileInfo(server_name, file_id)
 
-            if media_info["quarantined_by"]:
+            if media_info.quarantined_by:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
 
-            if not media_info["media_type"]:
-                media_info["media_type"] = "application/octet-stream"
+            if not media_info.media_type:
+                media_info = attr.evolve(
+                    media_info, media_type="application/octet-stream"
+                )
 
             responder = await self.media_storage.fetch_media(file_info)
             if responder:
@@ -403,9 +411,9 @@ class MediaRepository:
             if not media_info:
                 raise e
 
-        file_id = media_info["filesystem_id"]
-        if not media_info["media_type"]:
-            media_info["media_type"] = "application/octet-stream"
+        file_id = media_info.filesystem_id
+        if not media_info.media_type:
+            media_info = attr.evolve(media_info, media_type="application/octet-stream")
         file_info = FileInfo(server_name, file_id)
 
         # We generate thumbnails even if another process downloaded the media
@@ -415,7 +423,7 @@ class MediaRepository:
         # otherwise they'll request thumbnails and get a 404 if they're not
         # ready yet.
         await self._generate_thumbnails(
-            server_name, media_id, file_id, media_info["media_type"]
+            server_name, media_id, file_id, media_info.media_type
         )
 
         responder = await self.media_storage.fetch_media(file_info)
@@ -425,7 +433,7 @@ class MediaRepository:
         self,
         server_name: str,
         media_id: str,
-    ) -> dict:
+    ) -> RemoteMedia:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -518,7 +526,7 @@ class MediaRepository:
                 origin=server_name,
                 media_id=media_id,
                 media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
+                time_now_ms=time_now_ms,
                 upload_name=upload_name,
                 media_length=length,
                 filesystem_id=file_id,
@@ -526,15 +534,17 @@ class MediaRepository:
 
         logger.info("Stored remote media in file %r", fname)
 
-        media_info = {
-            "media_type": media_type,
-            "media_length": length,
-            "upload_name": upload_name,
-            "created_ts": time_now_ms,
-            "filesystem_id": file_id,
-        }
-
-        return media_info
+        return RemoteMedia(
+            media_origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            media_length=length,
+            upload_name=upload_name,
+            created_ts=time_now_ms,
+            filesystem_id=file_id,
+            last_access_ts=time_now_ms,
+            quarantined_by=None,
+        )
 
     def _get_thumbnail_requirements(
         self, media_type: str