summary refs log tree commit diff
path: root/synapse/media/media_repository.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/media/media_repository.py70
1 files changed, 60 insertions, 10 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py

index 8bc92305fe..18c5a8ecec 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py
@@ -52,13 +52,18 @@ from synapse.media._base import ( FileInfo, Responder, ThumbnailInfo, + check_for_cached_entry_and_respond, get_filename_from_headers, respond_404, respond_with_multipart_responder, respond_with_responder, ) from synapse.media.filepath import MediaFilePaths -from synapse.media.media_storage import MediaStorage +from synapse.media.media_storage import ( + MediaStorage, + SHA256TransparentIOReader, + SHA256TransparentIOWriter, +) from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer @@ -259,7 +264,7 @@ class MediaRepository: """ media = await self.store.get_local_media(media_id) if media is None: - raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND) + raise NotFoundError("Unknown media ID") if media.user_id != auth_user.to_string(): raise SynapseError( @@ -300,15 +305,26 @@ class MediaRepository: auth_user: The user_id of the uploader """ file_info = FileInfo(server_name=None, file_id=media_id) - fname = await self.media_storage.store_file(content, file_info) + sha256reader = SHA256TransparentIOReader(content) + # This implements all of IO as it has a passthrough + fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) + sha256 = sha256reader.hexdigest() + should_quarantine = await self.store.get_is_hash_quarantined(sha256) logger.info("Stored local media in file %r", fname) + if should_quarantine: + logger.warn( + "Media has been automatically quarantined as it matched existing quarantined media" + ) + await self.store.update_local_media( media_id=media_id, media_type=media_type, upload_name=upload_name, media_length=content_length, user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, ) try: @@ -341,11 +357,19 @@ class MediaRepository: media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) - - fname = await self.media_storage.store_file(content, file_info) + # This implements all of IO as it has a passthrough + sha256reader = SHA256TransparentIOReader(content) + fname = await self.media_storage.store_file(sha256reader.wrap(), file_info) + sha256 = sha256reader.hexdigest() + should_quarantine = await self.store.get_is_hash_quarantined(sha256) logger.info("Stored local media in file %r", fname) + if should_quarantine: + logger.warn( + "Media has been automatically quarantined as it matched existing quarantined media" + ) + await self.store.store_local_media( media_id=media_id, media_type=media_type, @@ -353,6 +377,8 @@ class MediaRepository: upload_name=upload_name, media_length=content_length, user_id=auth_user, + sha256=sha256, + quarantined_by="system" if should_quarantine else None, ) try: @@ -459,6 +485,11 @@ class MediaRepository: self.mark_recently_accessed(None, media_id) + # Once we've checked auth we can return early if the media is cached on + # the client + if check_for_cached_entry_and_respond(request): + return + media_type = media_info.media_type if not media_type: media_type = "application/octet-stream" @@ -471,7 +502,7 @@ class MediaRepository: responder = await self.media_storage.fetch_media(file_info) if federation: await respond_with_multipart_responder( - self.clock, request, responder, media_info + self.clock, request, responder, media_type, media_length, upload_name ) else: await respond_with_responder( @@ -538,6 +569,17 @@ class MediaRepository: allow_authenticated, ) + # Check if the media is cached on the client, if so return 304. We need + # to do this after we have fetched remote media, as we need it to do the + # auth. + if check_for_cached_entry_and_respond(request): + # We always need to use the responder. + if responder: + with responder: + pass + + return + # We deliberately stream the file outside the lock if responder and media_info: upload_name = name if name else media_info.upload_name @@ -739,11 +781,13 @@ class MediaRepository: file_info = FileInfo(server_name=server_name, file_id=file_id) async with self.media_storage.store_into_file(file_info) as (f, fname): + sha256writer = SHA256TransparentIOWriter(f) try: length, headers = await self.client.download_media( server_name, media_id, - output_stream=f, + # This implements all of BinaryIO as it has a passthrough + output_stream=sha256writer.wrap(), max_size=self.max_upload_size, max_timeout_ms=max_timeout_ms, download_ratelimiter=download_ratelimiter, @@ -808,6 +852,7 @@ class MediaRepository: upload_name=upload_name, media_length=length, filesystem_id=file_id, + sha256=sha256writer.hexdigest(), ) logger.info("Stored remote media in file %r", fname) @@ -828,6 +873,7 @@ class MediaRepository: last_access_ts=time_now_ms, quarantined_by=None, authenticated=authenticated, + sha256=sha256writer.hexdigest(), ) async def _federation_download_remote_file( @@ -862,11 +908,13 @@ class MediaRepository: file_info = FileInfo(server_name=server_name, file_id=file_id) async with self.media_storage.store_into_file(file_info) as (f, fname): + sha256writer = SHA256TransparentIOWriter(f) try: res = await self.client.federation_download_media( server_name, media_id, - output_stream=f, + # This implements all of BinaryIO as it has a passthrough + output_stream=sha256writer.wrap(), max_size=self.max_upload_size, max_timeout_ms=max_timeout_ms, download_ratelimiter=download_ratelimiter, @@ -937,6 +985,7 @@ class MediaRepository: upload_name=upload_name, media_length=length, filesystem_id=file_id, + sha256=sha256writer.hexdigest(), ) logger.debug("Stored remote media in file %r", fname) @@ -957,6 +1006,7 @@ class MediaRepository: last_access_ts=time_now_ms, quarantined_by=None, authenticated=authenticated, + sha256=sha256writer.hexdigest(), ) def _get_thumbnail_requirements( @@ -1008,7 +1058,7 @@ class MediaRepository: t_method: str, t_type: str, url_cache: bool, - ) -> Optional[str]: + ) -> Optional[Tuple[str, FileInfo]]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) @@ -1070,7 +1120,7 @@ class MediaRepository: t_len, ) - return output_path + return output_path, file_info # Could not generate thumbnail. return None