diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index deffdc19ce..3c585c555a 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -251,12 +251,36 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
)
- async def get_local_media_before(
+ async def get_local_media_ids(
self,
before_ts: int,
size_gt: int,
keep_profiles: bool,
+ include_quarantined_media: bool,
+ include_protected_media: bool,
) -> List[str]:
+ """
+ Retrieve a list of media IDs from the local media store.
+
+ Args:
+ before_ts: Only retrieve IDs from media that was either last accessed
+ (or if never accessed, created) before the given UNIX timestamp in ms.
+ size_gt: Only retrieve IDs from media that has a size (in bytes) greater than
+ the given integer.
+ keep_profiles: If True, exclude media IDs from the results that are used in the
+ following situations:
+ * global profile user avatar
+ * per-room profile user avatar
+ * room avatar
+ * a user's avatar in the user directory
+ include_quarantined_media: If False, exclude media IDs from the results that have
+ been marked as quarantined.
+ include_protected_media: If False, exclude media IDs from the results that have
+ been marked as protected from quarantine.
+
+ Returns:
+ A list of local media IDs.
+ """
# to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts`
@@ -294,12 +318,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
sql += sql_keep
- def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
+ if include_quarantined_media is False:
+ # Do not include media that has been quarantined
+ sql += """
+ AND quarantined_by IS NULL
+ """
+
+ if include_protected_media is False:
+ # Do not include media that has been protected from quarantine
+ sql += """
+ AND safe_from_quarantine = false
+ """
+
+ def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]:
txn.execute(sql, (before_ts, before_ts, size_gt))
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
- "get_local_media_before", _get_local_media_before_txn
+ "get_local_media_ids", _get_local_media_ids_txn
)
async def store_local_media(
@@ -599,15 +635,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_remote_media_thumbnail",
)
- async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
+ async def get_remote_media_ids(
+ self, before_ts: int, include_quarantined_media: bool
+ ) -> List[Dict[str, str]]:
+ """
+ Retrieve a list of server name, media ID tuples from the remote media cache.
+
+ Args:
+ before_ts: Only retrieve IDs from media that was either last accessed
+ (or if never accessed, created) before the given UNIX timestamp in ms.
+ include_quarantined_media: If False, exclude media IDs from the results that have
+ been marked as quarantined.
+
+ Returns:
+ A list of tuples containing:
+ * The server name of homeserver where the media originates from,
+ * The ID of the media.
+ """
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
+ if include_quarantined_media is False:
+ # Only include media that has not been quarantined
+ sql += """
+ AND quarantined_by IS NULL
+ """
+
return await self.db_pool.execute(
- "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
+ "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
)
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|