summary refs log tree commit diff
path: root/synapse/storage/databases/main/media_repository.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/media_repository.py')
-rw-r--r--synapse/storage/databases/main/media_repository.py398
1 files changed, 398 insertions, 0 deletions
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
new file mode 100644

index 0000000000..80fc1cd009 --- /dev/null +++ b/synapse/storage/databases/main/media_repository.py
@@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool + + +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) + + self.db_pool.updates.register_background_index_update( + update_name="local_media_repository_url_idx", + index_name="local_media_repository_url_idx", + table="local_media_repository", + columns=["created_ts"], + where_clause="url_cache IS NOT NULL", + ) + + +class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): + """Persistence for attachments and avatars""" + + def __init__(self, database: DatabasePool, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) + + def get_local_media(self, media_id): + """Get the metadata for a local piece of media + Returns: + None if the media_id doesn't exist. + """ + return self.db_pool.simple_select_one( + "local_media_repository", + {"media_id": media_id}, + ( + "media_type", + "media_length", + "upload_name", + "created_ts", + "quarantined_by", + "url_cache", + ), + allow_none=True, + desc="get_local_media", + ) + + def store_local_media( + self, + media_id, + media_type, + time_now_ms, + upload_name, + media_length, + user_id, + url_cache=None, + ): + return self.db_pool.simple_insert( + "local_media_repository", + { + "media_id": media_id, + "media_type": media_type, + "created_ts": time_now_ms, + "upload_name": upload_name, + "media_length": media_length, + "user_id": user_id.to_string(), + "url_cache": url_cache, + }, + desc="store_local_media", + ) + + def mark_local_media_as_safe(self, media_id: str): + """Mark a local media as safe from quarantining.""" + return self.db_pool.simple_update_one( + table="local_media_repository", + keyvalues={"media_id": media_id}, + updatevalues={"safe_from_quarantine": True}, + desc="mark_local_media_as_safe", + ) + + def get_url_cache(self, url, ts): + """Get the media_id and ts for a cached URL as of the given timestamp + Returns: + None if the URL isn't cached. + """ + + def get_url_cache_txn(txn): + # get the most recently cached result (relative to the given ts) + sql = ( + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts <= ?" + " ORDER BY download_ts DESC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + # ...or if we've requested a timestamp older than the oldest + # copy in the cache, return the oldest copy (if any) + sql = ( + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" + " FROM local_media_repository_url_cache" + " WHERE url = ? AND download_ts > ?" + " ORDER BY download_ts ASC LIMIT 1" + ) + txn.execute(sql, (url, ts)) + row = txn.fetchone() + + if not row: + return None + + return dict( + zip( + ( + "response_code", + "etag", + "expires_ts", + "og", + "media_id", + "download_ts", + ), + row, + ) + ) + + return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) + + def store_url_cache( + self, url, response_code, etag, expires_ts, og, media_id, download_ts + ): + return self.db_pool.simple_insert( + "local_media_repository_url_cache", + { + "url": url, + "response_code": response_code, + "etag": etag, + "expires_ts": expires_ts, + "og": og, + "media_id": media_id, + "download_ts": download_ts, + }, + desc="store_url_cache", + ) + + def get_local_media_thumbnails(self, media_id): + return self.db_pool.simple_select_list( + "local_media_repository_thumbnails", + {"media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + ), + desc="get_local_media_thumbnails", + ) + + def store_local_thumbnail( + self, + media_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): + return self.db_pool.simple_insert( + "local_media_repository_thumbnails", + { + "media_id": media_id, + "thumbnail_width": thumbnail_width, + "thumbnail_height": thumbnail_height, + "thumbnail_method": thumbnail_method, + "thumbnail_type": thumbnail_type, + "thumbnail_length": thumbnail_length, + }, + desc="store_local_thumbnail", + ) + + def get_cached_remote_media(self, origin, media_id): + return self.db_pool.simple_select_one( + "remote_media_cache", + {"media_origin": origin, "media_id": media_id}, + ( + "media_type", + "media_length", + "upload_name", + "created_ts", + "filesystem_id", + "quarantined_by", + ), + allow_none=True, + desc="get_cached_remote_media", + ) + + def store_cached_remote_media( + self, + origin, + media_id, + media_type, + media_length, + time_now_ms, + upload_name, + filesystem_id, + ): + return self.db_pool.simple_insert( + "remote_media_cache", + { + "media_origin": origin, + "media_id": media_id, + "media_type": media_type, + "media_length": media_length, + "created_ts": time_now_ms, + "upload_name": upload_name, + "filesystem_id": filesystem_id, + "last_access_ts": time_now_ms, + }, + desc="store_cached_remote_media", + ) + + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ + + def update_cache_txn(txn): + sql = ( + "UPDATE remote_media_cache SET last_access_ts = ?" + " WHERE media_origin = ? AND media_id = ?" + ) + + txn.executemany( + sql, + ( + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + ), + ) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" + ) + + txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + + return self.db_pool.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) + + def get_remote_media_thumbnails(self, origin, media_id): + return self.db_pool.simple_select_list( + "remote_media_cache_thumbnails", + {"media_origin": origin, "media_id": media_id}, + ( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + desc="get_remote_media_thumbnails", + ) + + def store_remote_media_thumbnail( + self, + origin, + media_id, + filesystem_id, + thumbnail_width, + thumbnail_height, + thumbnail_type, + thumbnail_method, + thumbnail_length, + ): + return self.db_pool.simple_insert( + "remote_media_cache_thumbnails", + { + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": thumbnail_width, + "thumbnail_height": thumbnail_height, + "thumbnail_method": thumbnail_method, + "thumbnail_type": thumbnail_type, + "thumbnail_length": thumbnail_length, + "filesystem_id": filesystem_id, + }, + desc="store_remote_media_thumbnail", + ) + + def get_remote_media_before(self, before_ts): + sql = ( + "SELECT media_origin, media_id, filesystem_id" + " FROM remote_media_cache" + " WHERE last_access_ts < ?" + ) + + return self.db_pool.execute( + "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts + ) + + def delete_remote_media(self, media_origin, media_id): + def delete_remote_media_txn(txn): + self.db_pool.simple_delete_txn( + txn, + "remote_media_cache", + keyvalues={"media_origin": media_origin, "media_id": media_id}, + ) + self.db_pool.simple_delete_txn( + txn, + "remote_media_cache_thumbnails", + keyvalues={"media_origin": media_origin, "media_id": media_id}, + ) + + return self.db_pool.runInteraction( + "delete_remote_media", delete_remote_media_txn + ) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.db_pool.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) + + async def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return await self.db_pool.runInteraction( + "delete_url_cache", _delete_url_cache_txn + ) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.db_pool.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn + ) + + async def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = "DELETE FROM local_media_repository WHERE media_id = ?" + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return await self.db_pool.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn + )