diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 8bc92305fe..18c5a8ecec 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -52,13 +52,18 @@ from synapse.media._base import (
FileInfo,
Responder,
ThumbnailInfo,
+ check_for_cached_entry_and_respond,
get_filename_from_headers,
respond_404,
respond_with_multipart_responder,
respond_with_responder,
)
from synapse.media.filepath import MediaFilePaths
-from synapse.media.media_storage import MediaStorage
+from synapse.media.media_storage import (
+ MediaStorage,
+ SHA256TransparentIOReader,
+ SHA256TransparentIOWriter,
+)
from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
@@ -259,7 +264,7 @@ class MediaRepository:
"""
media = await self.store.get_local_media(media_id)
if media is None:
- raise SynapseError(404, "Unknow media ID", errcode=Codes.NOT_FOUND)
+ raise NotFoundError("Unknown media ID")
if media.user_id != auth_user.to_string():
raise SynapseError(
@@ -300,15 +305,26 @@ class MediaRepository:
auth_user: The user_id of the uploader
"""
file_info = FileInfo(server_name=None, file_id=media_id)
- fname = await self.media_storage.store_file(content, file_info)
+ sha256reader = SHA256TransparentIOReader(content)
+ # This implements all of IO as it has a passthrough
+ fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
+ sha256 = sha256reader.hexdigest()
+ should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)
+ if should_quarantine:
+ logger.warn(
+ "Media has been automatically quarantined as it matched existing quarantined media"
+ )
+
await self.store.update_local_media(
media_id=media_id,
media_type=media_type,
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
+ sha256=sha256,
+ quarantined_by="system" if should_quarantine else None,
)
try:
@@ -341,11 +357,19 @@ class MediaRepository:
media_id = random_string(24)
file_info = FileInfo(server_name=None, file_id=media_id)
-
- fname = await self.media_storage.store_file(content, file_info)
+ # This implements all of IO as it has a passthrough
+ sha256reader = SHA256TransparentIOReader(content)
+ fname = await self.media_storage.store_file(sha256reader.wrap(), file_info)
+ sha256 = sha256reader.hexdigest()
+ should_quarantine = await self.store.get_is_hash_quarantined(sha256)
logger.info("Stored local media in file %r", fname)
+ if should_quarantine:
+ logger.warn(
+ "Media has been automatically quarantined as it matched existing quarantined media"
+ )
+
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
@@ -353,6 +377,8 @@ class MediaRepository:
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
+ sha256=sha256,
+ quarantined_by="system" if should_quarantine else None,
)
try:
@@ -459,6 +485,11 @@ class MediaRepository:
self.mark_recently_accessed(None, media_id)
+ # Once we've checked auth we can return early if the media is cached on
+ # the client
+ if check_for_cached_entry_and_respond(request):
+ return
+
media_type = media_info.media_type
if not media_type:
media_type = "application/octet-stream"
@@ -471,7 +502,7 @@ class MediaRepository:
responder = await self.media_storage.fetch_media(file_info)
if federation:
await respond_with_multipart_responder(
- self.clock, request, responder, media_info
+ self.clock, request, responder, media_type, media_length, upload_name
)
else:
await respond_with_responder(
@@ -538,6 +569,17 @@ class MediaRepository:
allow_authenticated,
)
+ # Check if the media is cached on the client, if so return 304. We need
+ # to do this after we have fetched remote media, as we need it to do the
+ # auth.
+ if check_for_cached_entry_and_respond(request):
+ # We always need to use the responder.
+ if responder:
+ with responder:
+ pass
+
+ return
+
# We deliberately stream the file outside the lock
if responder and media_info:
upload_name = name if name else media_info.upload_name
@@ -739,11 +781,13 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
+ sha256writer = SHA256TransparentIOWriter(f)
try:
length, headers = await self.client.download_media(
server_name,
media_id,
- output_stream=f,
+ # This implements all of BinaryIO as it has a passthrough
+ output_stream=sha256writer.wrap(),
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
@@ -808,6 +852,7 @@ class MediaRepository:
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
+ sha256=sha256writer.hexdigest(),
)
logger.info("Stored remote media in file %r", fname)
@@ -828,6 +873,7 @@ class MediaRepository:
last_access_ts=time_now_ms,
quarantined_by=None,
authenticated=authenticated,
+ sha256=sha256writer.hexdigest(),
)
async def _federation_download_remote_file(
@@ -862,11 +908,13 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
+ sha256writer = SHA256TransparentIOWriter(f)
try:
res = await self.client.federation_download_media(
server_name,
media_id,
- output_stream=f,
+ # This implements all of BinaryIO as it has a passthrough
+ output_stream=sha256writer.wrap(),
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
@@ -937,6 +985,7 @@ class MediaRepository:
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
+ sha256=sha256writer.hexdigest(),
)
logger.debug("Stored remote media in file %r", fname)
@@ -957,6 +1006,7 @@ class MediaRepository:
last_access_ts=time_now_ms,
quarantined_by=None,
authenticated=authenticated,
+ sha256=sha256writer.hexdigest(),
)
def _get_thumbnail_requirements(
@@ -1008,7 +1058,7 @@ class MediaRepository:
t_method: str,
t_type: str,
url_cache: bool,
- ) -> Optional[str]:
+ ) -> Optional[Tuple[str, FileInfo]]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
@@ -1070,7 +1120,7 @@ class MediaRepository:
t_len,
)
- return output_path
+ return output_path, file_info
# Could not generate thumbnail.
return None
|