diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 8cebeb5189..c8d7c9fd32 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,8 +26,11 @@ from typing import (
cast,
)
+import attr
+
from synapse.api.constants import Direction
from synapse.logging.opentracing import trace
+from synapse.media._base import ThumbnailInfo
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -44,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class LocalMedia:
+ media_id: str
+ media_type: str
+ media_length: int
+ upload_name: str
+ created_ts: int
+ last_access_ts: int
+ quarantined_by: Optional[str]
+ safe_from_quarantine: bool
+
+
class MediaSortOrder(Enum):
"""
Enum to define the sorting method used when returning media with
@@ -179,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: Direction = Direction.FORWARDS,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -196,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(
txn: LoggingTransaction,
- ) -> Tuple[List[Dict[str, Any]], int]:
+ ) -> Tuple[List[LocalMedia], int]:
# Set ordering
order_by_column = MediaSortOrder(order_by).value
@@ -216,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = """
SELECT
- "media_id",
- "media_type",
- "media_length",
- "upload_name",
- "created_ts",
- "last_access_ts",
- "quarantined_by",
- "safe_from_quarantine"
+ media_id,
+ media_type,
+ media_length,
+ upload_name,
+ created_ts,
+ last_access_ts,
+ quarantined_by,
+ safe_from_quarantine
FROM local_media_repository
WHERE user_id = ?
ORDER BY {order_by_column} {order}, media_id ASC
@@ -235,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
args += [limit, start]
txn.execute(sql, args)
- media = self.db_pool.cursor_to_dict(txn)
+ media = [
+ LocalMedia(
+ media_id=row[0],
+ media_type=row[1],
+ media_length=row[2],
+ upload_name=row[3],
+ created_ts=row[4],
+ last_access_ts=row[5],
+ quarantined_by=row[6],
+ safe_from_quarantine=bool(row[7]),
+ )
+ for row in txn
+ ]
return media, count
return await self.db_pool.runInteraction(
@@ -435,19 +462,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache",
)
- async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "local_media_repository_thumbnails",
- {"media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
+ async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]:
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "local_media_repository_thumbnails",
+ {"media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_local_media_thumbnails",
),
- desc="get_local_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
+ for row in rows
+ ]
@trace
async def store_local_thumbnail(
@@ -556,20 +592,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_thumbnails(
self, origin: str, media_id: str
- ) -> List[Dict[str, Any]]:
- return await self.db_pool.simple_select_list(
- "remote_media_cache_thumbnails",
- {"media_origin": origin, "media_id": media_id},
- (
- "thumbnail_width",
- "thumbnail_height",
- "thumbnail_method",
- "thumbnail_type",
- "thumbnail_length",
- "filesystem_id",
+ ) -> List[ThumbnailInfo]:
+ rows = cast(
+ List[Tuple[int, int, str, str, int]],
+ await self.db_pool.simple_select_list(
+ "remote_media_cache_thumbnails",
+ {"media_origin": origin, "media_id": media_id},
+ (
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_method",
+ "thumbnail_type",
+ "thumbnail_length",
+ ),
+ desc="get_remote_media_thumbnails",
),
- desc="get_remote_media_thumbnails",
)
+ return [
+ ThumbnailInfo(
+ width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
+ )
+ for row in rows
+ ]
@trace
async def get_remote_media_thumbnail(
@@ -632,7 +676,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
- ) -> List[Dict[str, str]]:
+ ) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.
@@ -646,12 +690,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
+ * The filesystem ID.
+ """
+
+ sql = """
+ SELECT media_origin, media_id, filesystem_id
+ FROM remote_media_cache
+ WHERE last_access_ts < ?
"""
- sql = (
- "SELECT media_origin, media_id, filesystem_id"
- " FROM remote_media_cache"
- " WHERE last_access_ts < ?"
- )
if include_quarantined_media is False:
# Only include media that has not been quarantined
@@ -659,8 +705,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL
"""
- return await self.db_pool.execute(
- "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
+ return cast(
+ List[Tuple[str, str, str]],
+ await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|