summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/media_repository.py105
1 files changed, 65 insertions, 40 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c8d7c9fd32..7f99c64f1b 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
-    Any,
     Collection,
-    Dict,
     Iterable,
     List,
     Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
     media_length: int
     upload_name: str
     created_ts: int
+    url_cache: Optional[str]
     last_access_ts: int
     quarantined_by: Optional[str]
     safe_from_quarantine: bool
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+    media_origin: str
+    media_id: str
+    media_type: str
+    media_length: int
+    upload_name: Optional[str]
+    filesystem_id: str
+    created_ts: int
+    last_access_ts: int
+    quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+    response_code: int
+    expires_ts: int
+    og: Union[str, bytes]
+
+
 class MediaSortOrder(Enum):
     """
     Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         super().__init__(database, db_conn, hs)
         self.server_name: str = hs.hostname
 
-    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+    async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
         """Get the metadata for a local piece of media
 
         Returns:
             None if the media_id doesn't exist.
         """
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "created_ts",
                 "quarantined_by",
                 "url_cache",
+                "last_access_ts",
                 "safe_from_quarantine",
             ),
             allow_none=True,
             desc="get_local_media",
         )
+        if row is None:
+            return None
+        return LocalMedia(media_id=media_id, **row)
 
     async def get_local_media_by_user_paginate(
         self,
@@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length,
                     upload_name,
                     created_ts,
+                    url_cache,
                     last_access_ts,
                     quarantined_by,
                     safe_from_quarantine
@@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     media_length=row[2],
                     upload_name=row[3],
                     created_ts=row[4],
-                    last_access_ts=row[5],
-                    quarantined_by=row[6],
-                    safe_from_quarantine=bool(row[7]),
+                    url_cache=row[5],
+                    last_access_ts=row[6],
+                    quarantined_by=row[7],
+                    safe_from_quarantine=bool(row[8]),
                 )
                 for row in txn
             ]
@@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="mark_local_media_as_safe",
         )
 
-    async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+    async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
             None if the URL isn't cached.
         """
 
-        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
             # get the most recently cached result (relative to the given ts)
-            sql = (
-                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                " FROM local_media_repository_url_cache"
-                " WHERE url = ? AND download_ts <= ?"
-                " ORDER BY download_ts DESC LIMIT 1"
-            )
+            sql = """
+                SELECT response_code, expires_ts, og
+                FROM local_media_repository_url_cache
+                WHERE url = ? AND download_ts <= ?
+                ORDER BY download_ts DESC LIMIT 1
+            """
             txn.execute(sql, (url, ts))
             row = txn.fetchone()
 
             if not row:
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
-                sql = (
-                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
-                    " FROM local_media_repository_url_cache"
-                    " WHERE url = ? AND download_ts > ?"
-                    " ORDER BY download_ts ASC LIMIT 1"
-                )
+                sql = """
+                    SELECT response_code, expires_ts, og
+                    FROM local_media_repository_url_cache
+                    WHERE url = ? AND download_ts > ?
+                    ORDER BY download_ts ASC LIMIT 1
+                """
                 txn.execute(sql, (url, ts))
                 row = txn.fetchone()
 
             if not row:
                 return None
 
-            return dict(
-                zip(
-                    (
-                        "response_code",
-                        "etag",
-                        "expires_ts",
-                        "og",
-                        "media_id",
-                        "download_ts",
-                    ),
-                    row,
-                )
-            )
+            return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
 
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
@@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         response_code: int,
         etag: Optional[str],
         expires_ts: int,
-        og: Optional[str],
+        og: str,
         media_id: str,
         download_ts: int,
     ) -> None:
@@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     async def get_cached_remote_media(
         self, origin: str, media_id: str
-    ) -> Optional[Dict[str, Any]]:
-        return await self.db_pool.simple_select_one(
+    ) -> Optional[RemoteMedia]:
+        row = await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "upload_name",
                 "created_ts",
                 "filesystem_id",
+                "last_access_ts",
                 "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
         )
+        if row is None:
+            return row
+        return RemoteMedia(media_origin=origin, media_id=media_id, **row)
 
     async def store_cached_remote_media(
         self,
@@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         t_width: int,
         t_height: int,
         t_type: str,
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[ThumbnailInfo]:
         """Fetch the thumbnail info of given width, height and type."""
 
-        return await self.db_pool.simple_select_one(
+        row = await self.db_pool.simple_select_one(
             table="remote_media_cache_thumbnails",
             keyvalues={
                 "media_origin": origin,
@@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 "thumbnail_method",
                 "thumbnail_type",
                 "thumbnail_length",
-                "filesystem_id",
             ),
             allow_none=True,
             desc="get_remote_media_thumbnail",
         )
+        if row is None:
+            return None
+        return ThumbnailInfo(
+            width=row["thumbnail_width"],
+            height=row["thumbnail_height"],
+            method=row["thumbnail_method"],
+            type=row["thumbnail_type"],
+            length=row["thumbnail_length"],
+        )
 
     @trace
     async def store_remote_media_thumbnail(