summary refs log tree commit diff
path: root/synapse/storage/databases/main/media_repository.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/media_repository.py')
-rw-r--r--synapse/storage/databases/main/media_repository.py131
1 files changed, 89 insertions, 42 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py

index 8cebeb5189..c8d7c9fd32 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -26,8 +26,11 @@ from typing import ( cast, ) +import attr + from synapse.api.constants import Direction from synapse.logging.opentracing import trace +from synapse.media._base import ThumbnailInfo from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -44,6 +47,18 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( ) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class LocalMedia: + media_id: str + media_type: str + media_length: int + upload_name: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + safe_from_quarantine: bool + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -179,7 +194,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id: str, order_by: str = MediaSortOrder.CREATED_TS.value, direction: Direction = Direction.FORWARDS, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: """Get a paginated list of metadata for a local piece of media which an user_id has uploaded @@ -196,7 +211,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def get_local_media_by_user_paginate_txn( txn: LoggingTransaction, - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> Tuple[List[LocalMedia], int]: # Set ordering order_by_column = MediaSortOrder(order_by).value @@ -216,14 +231,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = """ SELECT - "media_id", - "media_type", - "media_length", - "upload_name", - "created_ts", - "last_access_ts", - "quarantined_by", - "safe_from_quarantine" + media_id, + media_type, + media_length, + upload_name, + created_ts, + last_access_ts, + quarantined_by, + safe_from_quarantine FROM local_media_repository WHERE user_id = ? ORDER BY {order_by_column} {order}, media_id ASC @@ -235,7 +250,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): args += [limit, start] txn.execute(sql, args) - media = self.db_pool.cursor_to_dict(txn) + media = [ + LocalMedia( + media_id=row[0], + media_type=row[1], + media_length=row[2], + upload_name=row[3], + created_ts=row[4], + last_access_ts=row[5], + quarantined_by=row[6], + safe_from_quarantine=bool(row[7]), + ) + for row in txn + ] return media, count return await self.db_pool.runInteraction( @@ -435,19 +462,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_url_cache", ) - async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]: - return await self.db_pool.simple_select_list( - "local_media_repository_thumbnails", - {"media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", + async def get_local_media_thumbnails(self, media_id: str) -> List[ThumbnailInfo]: + rows = cast( + List[Tuple[int, int, str, str, int]], + await self.db_pool.simple_select_list( + "local_media_repository_thumbnails", + {"media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_local_media_thumbnails", ), - desc="get_local_media_thumbnails", ) + return [ + ThumbnailInfo( + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] + ) + for row in rows + ] @trace async def store_local_thumbnail( @@ -556,20 +592,28 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_thumbnails( self, origin: str, media_id: str - ) -> List[Dict[str, Any]]: - return await self.db_pool.simple_select_list( - "remote_media_cache_thumbnails", - {"media_origin": origin, "media_id": media_id}, - ( - "thumbnail_width", - "thumbnail_height", - "thumbnail_method", - "thumbnail_type", - "thumbnail_length", - "filesystem_id", + ) -> List[ThumbnailInfo]: + rows = cast( + List[Tuple[int, int, str, str, int]], + await self.db_pool.simple_select_list( + "remote_media_cache_thumbnails", + {"media_origin": origin, "media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_remote_media_thumbnails", ), - desc="get_remote_media_thumbnails", ) + return [ + ThumbnailInfo( + width=row[0], height=row[1], method=row[2], type=row[3], length=row[4] + ) + for row in rows + ] @trace async def get_remote_media_thumbnail( @@ -632,7 +676,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_remote_media_ids( self, before_ts: int, include_quarantined_media: bool - ) -> List[Dict[str, str]]: + ) -> List[Tuple[str, str, str]]: """ Retrieve a list of server name, media ID tuples from the remote media cache. @@ -646,12 +690,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): A list of tuples containing: * The server name of homeserver where the media originates from, * The ID of the media. + * The filesystem ID. + """ + + sql = """ + SELECT media_origin, media_id, filesystem_id + FROM remote_media_cache + WHERE last_access_ts < ? """ - sql = ( - "SELECT media_origin, media_id, filesystem_id" - " FROM remote_media_cache" - " WHERE last_access_ts < ?" - ) if include_quarantined_media is False: # Only include media that has not been quarantined @@ -659,8 +705,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): AND quarantined_by IS NULL """ - return await self.db_pool.execute( - "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts + return cast( + List[Tuple[str, str, str]], + await self.db_pool.execute("get_remote_media_ids", sql, before_ts), ) async def delete_remote_media(self, media_origin: str, media_id: str) -> None: