diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index c8d7c9fd32..7f99c64f1b 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -15,9 +15,7 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
- Any,
Collection,
- Dict,
Iterable,
List,
Optional,
@@ -54,11 +52,32 @@ class LocalMedia:
media_length: int
upload_name: str
created_ts: int
+ url_cache: Optional[str]
last_access_ts: int
quarantined_by: Optional[str]
safe_from_quarantine: bool
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RemoteMedia:
+ media_origin: str
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: Optional[str]
+ filesystem_id: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UrlCache:
+ response_code: int
+ expires_ts: int
+ og: Union[str, bytes]
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
super().__init__(database, db_conn, hs)
self.server_name: str = hs.hostname
- async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
+ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"created_ts",
"quarantined_by",
"url_cache",
+ "last_access_ts",
"safe_from_quarantine",
),
allow_none=True,
desc="get_local_media",
)
+ if row is None:
+ return None
+ return LocalMedia(media_id=media_id, **row)
async def get_local_media_by_user_paginate(
self,
@@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length,
upload_name,
created_ts,
+ url_cache,
last_access_ts,
quarantined_by,
safe_from_quarantine
@@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length=row[2],
upload_name=row[3],
created_ts=row[4],
- last_access_ts=row[5],
- quarantined_by=row[6],
- safe_from_quarantine=bool(row[7]),
+ url_cache=row[5],
+ last_access_ts=row[6],
+ quarantined_by=row[7],
+ safe_from_quarantine=bool(row[8]),
)
for row in txn
]
@@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
- async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
+ async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
"""
- def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+ def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]:
# get the most recently cached result (relative to the given ts)
- sql = (
- "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
- " FROM local_media_repository_url_cache"
- " WHERE url = ? AND download_ts <= ?"
- " ORDER BY download_ts DESC LIMIT 1"
- )
+ sql = """
+ SELECT response_code, expires_ts, og
+ FROM local_media_repository_url_cache
+ WHERE url = ? AND download_ts <= ?
+ ORDER BY download_ts DESC LIMIT 1
+ """
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
- sql = (
- "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
- " FROM local_media_repository_url_cache"
- " WHERE url = ? AND download_ts > ?"
- " ORDER BY download_ts ASC LIMIT 1"
- )
+ sql = """
+ SELECT response_code, expires_ts, og
+ FROM local_media_repository_url_cache
+ WHERE url = ? AND download_ts > ?
+ ORDER BY download_ts ASC LIMIT 1
+ """
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
return None
- return dict(
- zip(
- (
- "response_code",
- "etag",
- "expires_ts",
- "og",
- "media_id",
- "download_ts",
- ),
- row,
- )
- )
+ return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2])
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
@@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
response_code: int,
etag: Optional[str],
expires_ts: int,
- og: Optional[str],
+ og: str,
media_id: str,
download_ts: int,
) -> None:
@@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_cached_remote_media(
self, origin: str, media_id: str
- ) -> Optional[Dict[str, Any]]:
- return await self.db_pool.simple_select_one(
+ ) -> Optional[RemoteMedia]:
+ row = await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"upload_name",
"created_ts",
"filesystem_id",
+ "last_access_ts",
"quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
+ if row is None:
+ return row
+ return RemoteMedia(media_origin=origin, media_id=media_id, **row)
async def store_cached_remote_media(
self,
@@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
t_width: int,
t_height: int,
t_type: str,
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[ThumbnailInfo]:
"""Fetch the thumbnail info of given width, height and type."""
- return await self.db_pool.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",
keyvalues={
"media_origin": origin,
@@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
- "filesystem_id",
),
allow_none=True,
desc="get_remote_media_thumbnail",
)
+ if row is None:
+ return None
+ return ThumbnailInfo(
+ width=row["thumbnail_width"],
+ height=row["thumbnail_height"],
+ method=row["thumbnail_method"],
+ type=row["thumbnail_type"],
+ length=row["thumbnail_length"],
+ )
@trace
async def store_remote_media_thumbnail(
|