summary refs log tree commit diff
path: root/synapse/storage/media_repository.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/media_repository.py')
-rw-r--r--synapse/storage/media_repository.py117
1 files changed, 104 insertions, 13 deletions
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 82bb61b811..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -12,15 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from synapse.storage.background_updates import BackgroundUpdateStore
 
-from ._base import SQLBaseStore
 
-
-class MediaRepositoryStore(SQLBaseStore):
+class MediaRepositoryStore(BackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def get_default_thumbnails(self, top_level_type, sub_type):
-        return []
+    def __init__(self, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            update_name='local_media_repository_url_idx',
+            index_name='local_media_repository_url_idx',
+            table='local_media_repository',
+            columns=['created_ts'],
+            where_clause='url_cache IS NOT NULL',
+        )
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
@@ -62,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore):
         def get_url_cache_txn(txn):
             # get the most recently cached result (relative to the given ts)
             sql = (
-                "SELECT response_code, etag, expires, og, media_id, download_ts"
+                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                 " FROM local_media_repository_url_cache"
                 " WHERE url = ? AND download_ts <= ?"
                 " ORDER BY download_ts DESC LIMIT 1"
@@ -74,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
                 sql = (
-                    "SELECT response_code, etag, expires, og, media_id, download_ts"
+                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                     " FROM local_media_repository_url_cache"
                     " WHERE url = ? AND download_ts > ?"
                     " ORDER BY download_ts ASC LIMIT 1"
@@ -86,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore):
                 return None
 
             return dict(zip((
-                'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+                'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
             ), row))
 
         return self.runInteraction(
             "get_url_cache", get_url_cache_txn
         )
 
-    def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+    def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
                         download_ts):
         return self._simple_insert(
             "local_media_repository_url_cache",
@@ -101,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 "url": url,
                 "response_code": response_code,
                 "etag": etag,
-                "expires": expires,
+                "expires_ts": expires_ts,
                 "og": og,
                 "media_id": media_id,
                 "download_ts": download_ts,
@@ -166,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+        """Updates the last access time of the given media
+
+        Args:
+            local_media (iterable[str]): Set of media_ids
+            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            time_ms: Current time in milliseconds
+        """
         def update_cache_txn(txn):
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
@@ -174,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore):
             )
 
             txn.executemany(sql, (
-                (time_ts, media_origin, media_id)
-                for media_origin, media_id in origin_id_tuples
+                (time_ms, media_origin, media_id)
+                for media_origin, media_id in remote_media
+            ))
+
+            sql = (
+                "UPDATE local_media_repository SET last_access_ts = ?"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, (
+                (time_ms, media_id)
+                for media_id in local_media
             ))
 
         return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@@ -238,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore):
                 },
             )
         return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+    def get_expired_url_cache(self, now_ts):
+        sql = (
+            "SELECT media_id FROM local_media_repository_url_cache"
+            " WHERE expires_ts < ?"
+            " ORDER BY expires_ts ASC"
+            " LIMIT 500"
+        )
+
+        def _get_expired_url_cache_txn(txn):
+            txn.execute(sql, (now_ts,))
+            return [row[0] for row in txn]
+
+        return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+    def delete_url_cache(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
+        sql = (
+            "DELETE FROM local_media_repository_url_cache"
+            " WHERE media_id = ?"
+        )
+
+        def _delete_url_cache_txn(txn):
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+        return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
+    def get_url_cache_media_before(self, before_ts):
+        sql = (
+            "SELECT media_id FROM local_media_repository"
+            " WHERE created_ts < ? AND url_cache IS NOT NULL"
+            " ORDER BY created_ts ASC"
+            " LIMIT 500"
+        )
+
+        def _get_url_cache_media_before_txn(txn):
+            txn.execute(sql, (before_ts,))
+            return [row[0] for row in txn]
+
+        return self.runInteraction(
+            "get_url_cache_media_before", _get_url_cache_media_before_txn,
+        )
+
+    def delete_url_cache_media(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
+        def _delete_url_cache_media_txn(txn):
+            sql = (
+                "DELETE FROM local_media_repository"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+            sql = (
+                "DELETE FROM local_media_repository_thumbnails"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+        return self.runInteraction(
+            "delete_url_cache_media", _delete_url_cache_media_txn,
+        )