diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 6568e61829..67aa993f19 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -213,6 +213,12 @@ async def respond_with_responder(
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
if not responder:
respond_404(request)
return
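As a self-contained illustration of the early-return guard added above (the request object here is a stand-in built with SimpleNamespace, not a real twisted.web request):

import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)

def should_respond(request) -> bool:
    """Return True if a response should still be written to `request`."""
    if request._disconnected:
        logger.warning(
            "Not sending response to request %s, already disconnected.", request
        )
        return False
    return True

# A client that went away while the media was still being fetched:
print(should_respond(SimpleNamespace(_disconnected=True)))   # False
print(should_respond(SimpleNamespace(_disconnected=False)))  # True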
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 7447eeaebe..9e079f672f 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -69,6 +69,23 @@ class MediaFilePaths:
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+ def local_media_thumbnail_dir(self, media_id: str) -> str:
+ """
+ Retrieve the local store path of the thumbnail directory for a given media_id
+
+ Args:
+ media_id: The media ID to query.
+ Returns:
+ The local path of the thumbnail directory for the given media_id
+ """
+ return os.path.join(
+ self.base_path,
+ "local_thumbnails",
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
+ )
+
def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index e1192b47cd..9cac74ebd8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,7 +18,7 @@ import errno
import logging
import os
import shutil
-from typing import IO, Dict, Optional, Tuple
+from typing import IO, Dict, List, Optional, Tuple
import twisted.internet.error
import twisted.web.http
@@ -305,15 +305,12 @@ class MediaRepository:
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
# one.
- if media_info:
- file_id = media_info["filesystem_id"]
- else:
- file_id = random_string(24)
-
- file_info = FileInfo(server_name, file_id)
# If we have an entry in the DB, try and look for it
if media_info:
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
@@ -324,14 +321,34 @@ class MediaRepository:
# Failed to find the file anywhere, let's download it.
- media_info = await self._download_remote_file(server_name, media_id, file_id)
+ try:
+ media_info = await self._download_remote_file(server_name, media_id)
+ except SynapseError:
+ raise
+ except Exception as e:
+ # The download may have failed because another process fetched the
+ # same media concurrently, so check whether it is now in the cache.
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
+ if not media_info:
+ raise e
+
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
+ # We generate thumbnails even if another process downloaded the media
+ # as a) it's conceivable that the other download request dies before it
+ # generates thumbnails, but mainly b) we want to be sure the thumbnails
+ # have finished being generated before responding to the client,
+ # otherwise they'll request thumbnails and get a 404 if they're not
+ # ready yet.
+ await self._generate_thumbnails(
+ server_name, media_id, file_id, media_info["media_type"]
+ )
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(
- self, server_name: str, media_id: str, file_id: str
- ) -> dict:
+ async def _download_remote_file(self, server_name: str, media_id: str) -> dict:
"""Attempt to download the remote file from the given server name,
generating a new file_id to use as the local id.
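The calling code above falls back to the media cache when the download fails, on the theory that another media repository worker may have fetched the same file concurrently. A minimal sketch of that pattern, using stand-in download() and get_cached() coroutines rather than Synapse's real interfaces:

import asyncio
from typing import Optional

async def download() -> dict:
    # Stand-in for _download_remote_file(); pretend a racing worker caused
    # e.g. a unique-constraint violation while storing the media row.
    raise RuntimeError("constraint violation from a racing download")

async def get_cached() -> Optional[dict]:
    # Stand-in for store.get_cached_remote_media(); the other worker's
    # download has already landed in the database.
    return {"filesystem_id": "abcdefghij", "media_type": "image/png"}

async def get_remote_media() -> dict:
    try:
        return await download()
    except Exception:
        # The failure may just mean another process beat us to it, so
        # check the cache before propagating the error.
        media_info = await get_cached()
        if media_info is None:
            raise
        return media_info

print(asyncio.run(get_remote_media()))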
@@ -346,6 +363,8 @@ class MediaRepository:
The media info of the file.
"""
+ file_id = random_string(24)
+
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -401,22 +420,32 @@ class MediaRepository:
await finish()
- media_type = headers[b"Content-Type"][0].decode("ascii")
- upload_name = get_filename_from_headers(headers)
- time_now_ms = self.clock.time_msec()
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ upload_name = get_filename_from_headers(headers)
+ time_now_ms = self.clock.time_msec()
+
+ # Multiple remote media download requests can race (when using
+ # multiple media repos), so this may throw a constraint violation
+ # exception. If it does we'll delete the newly downloaded file from
+ # disk (as we're in the ctx manager).
+ #
+ # However: we've already called `finish()` so we may have also
+ # written to the storage providers. This is preferable to the
+ # alternative where we call `finish()` *after* this, where we could
+ # end up having an entry in the DB but fail to write the files to
+ # the storage providers.
+ await self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=time_now_ms,
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
logger.info("Stored remote media in file %r", fname)
- await self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
-
media_info = {
"media_type": media_type,
"media_length": length,
@@ -425,8 +454,6 @@ class MediaRepository:
"filesystem_id": file_id,
}
- await self._generate_thumbnails(server_name, media_id, file_id, media_type)
-
return media_info
def _get_thumbnail_requirements(self, media_type):
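The constraint-violation race described in the comment above can be reproduced with a plain sqlite3 table (purely illustrative; Synapse's storage layer and schema are not shown here):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE remote_media_cache "
    "(origin TEXT, media_id TEXT, filesystem_id TEXT, PRIMARY KEY (origin, media_id))"
)

def store_cached_remote_media(origin: str, media_id: str, filesystem_id: str) -> None:
    conn.execute(
        "INSERT INTO remote_media_cache VALUES (?, ?, ?)",
        (origin, media_id, filesystem_id),
    )

store_cached_remote_media("example.com", "abc", "file1")
try:
    # A second media repo worker racing on the same media ID hits the
    # primary key; the newly written file would then be cleaned up by the
    # store_into_file context manager and the caller re-checks the cache.
    store_cached_remote_media("example.com", "abc", "file2")
except sqlite3.IntegrityError:
    row = conn.execute(
        "SELECT filesystem_id FROM remote_media_cache WHERE origin = ? AND media_id = ?",
        ("example.com", "abc"),
    ).fetchone()
    print("another process already stored", row[0])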
@@ -692,42 +719,60 @@ class MediaRepository:
if not t_byte_source:
continue
- try:
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
- url_cache=url_cache,
- )
-
- output_path = await self.media_storage.store_file(
- t_byte_source, file_info
- )
- finally:
- t_byte_source.close()
-
- t_len = os.path.getsize(output_path)
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
- # Write to database
- if server_name:
- await self.store.store_remote_media_thumbnail(
- server_name,
- media_id,
- file_id,
- t_width,
- t_height,
- t_type,
- t_method,
- t_len,
- )
- else:
- await self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ await self.media_storage.write_to_file(t_byte_source, f)
+ await finish()
+ finally:
+ t_byte_source.close()
+
+ t_len = os.path.getsize(fname)
+
+ # Write to database
+ if server_name:
+ # Multiple remote media download requests can race (when
+ # using multiple media repos), so this may throw a constraint
+ # violation exception. If it does we'll delete the newly
+ # generated thumbnail from disk (as we're in the ctx
+ # manager).
+ #
+ # However: we've already called `finish()` so we may have
+ # also written to the storage providers. This is preferable
+ # to the alternative where we call `finish()` *after* this,
+ # where we could end up having an entry in the DB but fail
+ # to write the files to the storage providers.
+ try:
+ await self.store.store_remote_media_thumbnail(
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
+ )
+ except Exception as e:
+ thumbnail_exists = await self.store.get_remote_media_thumbnail(
+ server_name, media_id, t_width, t_height, t_type,
+ )
+ if not thumbnail_exists:
+ raise e
+ else:
+ await self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
return {"width": m_width, "height": m_height}
@@ -767,6 +812,76 @@ class MediaRepository:
return {"deleted": deleted}
+ async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]:
+ """
+ Delete the given local media ID from this server
+
+ Args:
+ media_id: The media ID to delete.
+ Returns:
+ A tuple of (list of deleted media IDs, total deleted media IDs).
+ """
+ return await self._remove_local_media_from_disk([media_id])
+
+ async def delete_old_local_media(
+ self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True,
+ ) -> Tuple[List[str], int]:
+ """
+ Delete old local media from this server, filtered by size and last-access
+ timestamp. Removes media files, any thumbnails and cached URLs.
+
+ Args:
+ before_ts: Unix timestamp in ms.
+ Files that were last used before this timestamp will be deleted.
+ size_gt: Size of the media in bytes. Files that are larger will be deleted.
+ keep_profiles: Whether to keep media that is still used in image data
+ (e.g. user profiles, room avatars). If False, such files will also be deleted.
+ Returns:
+ A tuple of (list of deleted media IDs, total deleted media IDs).
+ """
+ old_media = await self.store.get_local_media_before(
+ before_ts, size_gt, keep_profiles,
+ )
+ return await self._remove_local_media_from_disk(old_media)
+
+ async def _remove_local_media_from_disk(
+ self, media_ids: List[str]
+ ) -> Tuple[List[str], int]:
+ """
+ Delete local media from this server. Removes media files,
+ any thumbnails and cached URLs.
+
+ Args:
+ media_ids: List of media IDs to delete
+ Returns:
+ A tuple of (list of deleted media IDs, total deleted media IDs).
+ """
+ removed_media = []
+ for media_id in media_ids:
+ logger.info("Deleting media with ID '%s'", media_id)
+ full_path = self.filepaths.local_media_filepath(media_id)
+ try:
+ os.remove(full_path)
+ except OSError as e:
+ logger.warning("Failed to remove file: %r: %s", full_path, e)
+ if e.errno != errno.ENOENT:
+ continue
+
+ thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id)
+ shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+ await self.store.delete_remote_media(self.server_name, media_id)
+
+ await self.store.delete_url_cache((media_id,))
+ await self.store.delete_url_cache_media((media_id,))
+
+ removed_media.append(media_id)
+
+ return removed_media, len(removed_media)
+
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage:
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
+ self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@@ -70,13 +71,16 @@ class MediaStorage:
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- await defer_to_thread(
- self.hs.get_reactor(), _write_file_synchronously, source, f
- )
+ await self.write_to_file(source, f)
await finish_cb()
return fname
+ async def write_to_file(self, source: IO, output: IO):
+ """Asynchronously write the `source` to `output`.
+ """
+ await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
@@ -112,14 +116,20 @@ class MediaStorage:
finished_called = [False]
- async def finish():
- for provider in self.storage_providers:
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
try:
with open(fname, "wb") as f:
+
+ async def finish():
+ # Ensure that all writes have been flushed and close the
+ # file.
+ f.flush()
+ f.close()
+
+ for provider in self.storage_providers:
+ await provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
yield f, fname, finish
except Exception:
try:
@@ -210,7 +220,7 @@ class MediaStorage:
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.hs.get_reactor()
+ open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
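Stripped of Twisted and the storage providers, the calling pattern that store_into_file, the new write_to_file helper and finish() are meant to follow looks roughly like this (a simplified synchronous sketch, not the real MediaStorage class):

import contextlib
import io
import os
import tempfile
from typing import IO, Iterator, Tuple

@contextlib.contextmanager
def store_into_file(fname: str) -> Iterator[Tuple[IO[bytes], str]]:
    # Simplified stand-in: open the target file and remove it again if the
    # caller's block raises, mirroring the clean-up in the real context manager.
    try:
        with open(fname, "wb") as f:
            yield f, fname
    except Exception:
        os.remove(fname)
        raise

def write_to_file(source: IO[bytes], output: IO[bytes]) -> None:
    # Synchronous stand-in for MediaStorage.write_to_file, which performs the
    # copy on a background thread via defer_to_thread in the real code.
    output.write(source.read())

with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "media")
    with store_into_file(target) as (f, fname):
        write_to_file(io.BytesIO(b"example media bytes"), f)
        # The real code awaits finish() at this point, which flushes and
        # closes the file before uploading it to the storage providers.
    print(os.path.getsize(fname))  # size of the bytes written above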