summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/media_repository.py68
1 files changed, 63 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index deffdc19ce..3c585c555a 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -251,12 +251,36 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
         )
 
-    async def get_local_media_before(
+    async def get_local_media_ids(
         self,
         before_ts: int,
         size_gt: int,
         keep_profiles: bool,
+        include_quarantined_media: bool,
+        include_protected_media: bool,
     ) -> List[str]:
+        """
+        Retrieve a list of media IDs from the local media store.
+
+        Args:
+            before_ts: Only retrieve IDs from media that was either last accessed
+                (or if never accessed, created) before the given UNIX timestamp in ms.
+            size_gt: Only retrieve IDs from media that has a size (in bytes) greater than
+                the given integer.
+            keep_profiles: If True, exclude media IDs from the results that are used in the
+                following situations:
+                    * global profile user avatar
+                    * per-room profile user avatar
+                    * room avatar
+                    * a user's avatar in the user directory
+            include_quarantined_media: If False, exclude media IDs from the results that have
+                been marked as quarantined.
+            include_protected_media: If False, exclude media IDs from the results that have
+                been marked as protected from quarantine.
+
+        Returns:
+            A list of local media IDs.
+        """
 
         # to find files that have never been accessed (last_access_ts IS NULL)
         # compare with `created_ts`
@@ -294,12 +318,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             )
             sql += sql_keep
 
-        def _get_local_media_before_txn(txn: LoggingTransaction) -> List[str]:
+        if include_quarantined_media is False:
+            # Do not include media that has been quarantined
+            sql += """
+                AND quarantined_by IS NULL
+            """
+
+        if include_protected_media is False:
+            # Do not include media that has been protected from quarantine
+            sql += """
+                AND safe_from_quarantine = false
+            """
+
+        def _get_local_media_ids_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (before_ts, before_ts, size_gt))
             return [row[0] for row in txn]
 
         return await self.db_pool.runInteraction(
-            "get_local_media_before", _get_local_media_before_txn
+            "get_local_media_ids", _get_local_media_ids_txn
         )
 
     async def store_local_media(
@@ -599,15 +635,37 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_remote_media_thumbnail",
         )
 
-    async def get_remote_media_before(self, before_ts: int) -> List[Dict[str, str]]:
+    async def get_remote_media_ids(
+        self, before_ts: int, include_quarantined_media: bool
+    ) -> List[Dict[str, str]]:
+        """
+        Retrieve a list of server name, media ID tuples from the remote media cache.
+
+        Args:
+            before_ts: Only retrieve IDs from media that was either last accessed
+                (or if never accessed, created) before the given UNIX timestamp in ms.
+            include_quarantined_media: If False, exclude media IDs from the results that have
+                been marked as quarantined.
+
+        Returns:
+            A list of tuples containing:
+                * The server name of homeserver where the media originates from,
+                * The ID of the media.
+        """
         sql = (
             "SELECT media_origin, media_id, filesystem_id"
             " FROM remote_media_cache"
             " WHERE last_access_ts < ?"
         )
 
+        if include_quarantined_media is False:
+            # Only include media that has not been quarantined
+            sql += """
+            AND quarantined_by IS NULL
+            """
+
         return await self.db_pool.execute(
-            "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
+            "get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
         )
 
     async def delete_remote_media(self, media_origin: str, media_id: str) -> None: