diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/media_repository.py | 105 |
1 files changed, 65 insertions, 40 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index c8d7c9fd32..7f99c64f1b 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -15,9 +15,7 @@ from enum import Enum from typing import ( TYPE_CHECKING, - Any, Collection, - Dict, Iterable, List, Optional, @@ -54,11 +52,32 @@ class LocalMedia: media_length: int upload_name: str created_ts: int + url_cache: Optional[str] last_access_ts: int quarantined_by: Optional[str] safe_from_quarantine: bool +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RemoteMedia: + media_origin: str + media_id: str + media_type: str + media_length: int + upload_name: Optional[str] + filesystem_id: str + created_ts: int + last_access_ts: int + quarantined_by: Optional[str] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UrlCache: + response_code: int + expires_ts: int + og: Union[str, bytes] + + class MediaSortOrder(Enum): """ Enum to define the sorting method used when returning media with @@ -165,13 +184,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self.server_name: str = hs.hostname - async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: + async def get_local_media(self, media_id: str) -> Optional[LocalMedia]: """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. """ - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -181,11 +200,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "created_ts", "quarantined_by", "url_cache", + "last_access_ts", "safe_from_quarantine", ), allow_none=True, desc="get_local_media", ) + if row is None: + return None + return LocalMedia(media_id=media_id, **row) async def get_local_media_by_user_paginate( self, @@ -236,6 +259,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length, upload_name, created_ts, + url_cache, last_access_ts, quarantined_by, safe_from_quarantine @@ -257,9 +281,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): media_length=row[2], upload_name=row[3], created_ts=row[4], - last_access_ts=row[5], - quarantined_by=row[6], - safe_from_quarantine=bool(row[7]), + url_cache=row[5], + last_access_ts=row[6], + quarantined_by=row[7], + safe_from_quarantine=bool(row[8]), ) for row in txn ] @@ -390,51 +415,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="mark_local_media_as_safe", ) - async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]: + async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: None if the URL isn't cached. """ - def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: + def get_url_cache_txn(txn: LoggingTransaction) -> Optional[UrlCache]: # get the most recently cached result (relative to the given ts) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts <= ?" - " ORDER BY download_ts DESC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts <= ? + ORDER BY download_ts DESC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) - sql = ( - "SELECT response_code, etag, expires_ts, og, media_id, download_ts" - " FROM local_media_repository_url_cache" - " WHERE url = ? AND download_ts > ?" - " ORDER BY download_ts ASC LIMIT 1" - ) + sql = """ + SELECT response_code, expires_ts, og + FROM local_media_repository_url_cache + WHERE url = ? AND download_ts > ? + ORDER BY download_ts ASC LIMIT 1 + """ txn.execute(sql, (url, ts)) row = txn.fetchone() if not row: return None - return dict( - zip( - ( - "response_code", - "etag", - "expires_ts", - "og", - "media_id", - "download_ts", - ), - row, - ) - ) + return UrlCache(response_code=row[0], expires_ts=row[1], og=row[2]) return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) @@ -444,7 +457,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): response_code: int, etag: Optional[str], expires_ts: int, - og: Optional[str], + og: str, media_id: str, download_ts: int, ) -> None: @@ -510,8 +523,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): async def get_cached_remote_media( self, origin: str, media_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( + ) -> Optional[RemoteMedia]: + row = await self.db_pool.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -520,11 +533,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "upload_name", "created_ts", "filesystem_id", + "last_access_ts", "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", ) + if row is None: + return row + return RemoteMedia(media_origin=origin, media_id=media_id, **row) async def store_cached_remote_media( self, @@ -623,10 +640,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): t_width: int, t_height: int, t_type: str, - ) -> Optional[Dict[str, Any]]: + ) -> Optional[ThumbnailInfo]: """Fetch the thumbnail info of given width, height and type.""" - return await self.db_pool.simple_select_one( + row = await self.db_pool.simple_select_one( table="remote_media_cache_thumbnails", keyvalues={ "media_origin": origin, @@ -641,11 +658,19 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "thumbnail_method", "thumbnail_type", "thumbnail_length", - "filesystem_id", ), allow_none=True, desc="get_remote_media_thumbnail", ) + if row is None: + return None + return ThumbnailInfo( + width=row["thumbnail_width"], + height=row["thumbnail_height"], + method=row["thumbnail_method"], + type=row["thumbnail_type"], + length=row["thumbnail_length"], + ) @trace async def store_remote_media_thumbnail( |