diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/media_repository.py | 90 |
1 files changed, 88 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 3f80a64dc5..149135b8b5 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -49,13 +49,14 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = ( class LocalMedia: media_id: str media_type: str - media_length: int + media_length: Optional[int] upload_name: str created_ts: int url_cache: Optional[str] last_access_ts: int quarantined_by: Optional[str] safe_from_quarantine: bool + user_id: Optional[str] @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -149,6 +150,13 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): self._drop_media_index_without_method, ) + if hs.config.media.can_load_media_repo: + self.unused_expiration_time: Optional[ + int + ] = hs.config.media.unused_expiration_time + else: + self.unused_expiration_time = None + async def _drop_media_index_without_method( self, progress: JsonDict, batch_size: int ) -> int: @@ -202,6 +210,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "url_cache", "last_access_ts", "safe_from_quarantine", + "user_id", ), allow_none=True, desc="get_local_media", @@ -218,6 +227,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): url_cache=row[5], last_access_ts=row[6], safe_from_quarantine=row[7], + user_id=row[8], ) async def get_local_media_by_user_paginate( @@ -272,7 +282,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): url_cache, last_access_ts, quarantined_by, - safe_from_quarantine + safe_from_quarantine, + user_id FROM local_media_repository WHERE user_id = ? ORDER BY {order_by_column} {order}, media_id ASC @@ -295,6 +306,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): last_access_ts=row[6], quarantined_by=row[7], safe_from_quarantine=bool(row[8]), + user_id=row[9], ) for row in txn ] @@ -392,6 +404,23 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) @trace + async def store_local_media_id( + self, + media_id: str, + time_now_ms: int, + user_id: UserID, + ) -> None: + await self.db_pool.simple_insert( + "local_media_repository", + { + "media_id": media_id, + "created_ts": time_now_ms, + "user_id": user_id.to_string(), + }, + desc="store_local_media_id", + ) + + @trace async def store_local_media( self, media_id: str, @@ -416,6 +445,30 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) + async def update_local_media( + self, + media_id: str, + media_type: str, + upload_name: Optional[str], + media_length: int, + user_id: UserID, + url_cache: Optional[str] = None, + ) -> None: + await self.db_pool.simple_update_one( + "local_media_repository", + keyvalues={ + "user_id": user_id.to_string(), + "media_id": media_id, + }, + updatevalues={ + "media_type": media_type, + "upload_name": upload_name, + "media_length": media_length, + "url_cache": url_cache, + }, + desc="update_local_media", + ) + async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None: """Mark a local media as safe or unsafe from quarantining.""" await self.db_pool.simple_update_one( @@ -425,6 +478,39 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="mark_local_media_as_safe", ) + async def count_pending_media(self, user_id: UserID) -> Tuple[int, int]: + """Count the number of pending media for a user. + + Returns: + A tuple of two integers: the total pending media requests and the earliest + expiration timestamp. + """ + + def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]: + sql = """ + SELECT COUNT(*), MIN(created_ts) + FROM local_media_repository + WHERE user_id = ? + AND created_ts > ? + AND media_length IS NULL + """ + assert self.unused_expiration_time is not None + txn.execute( + sql, + ( + user_id.to_string(), + self._clock.time_msec() - self.unused_expiration_time, + ), + ) + row = txn.fetchone() + if not row: + return 0, 0 + return row[0], (row[1] + self.unused_expiration_time if row[1] else 0) + + return await self.db_pool.runInteraction( + "get_pending_media", get_pending_media_txn + ) + async def get_url_cache(self, url: str, ts: int) -> Optional[UrlCache]: """Get the media_id and ts for a cached URL as of the given timestamp Returns: |