diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index 2a106bb0eb..afd33c02a1 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -19,6 +19,7 @@
#
#
import contextlib
+import hashlib
import json
import logging
import os
@@ -49,15 +50,11 @@ from zope.interface import implementer
from twisted.internet import interfaces
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
-from twisted.protocols.basic import FileSender
from synapse.api.errors import NotFoundError
-from synapse.logging.context import (
- defer_to_thread,
- make_deferred_yieldable,
- run_in_background,
-)
+from synapse.logging.context import defer_to_thread, run_in_background
from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
+from synapse.media._base import ThreadedFileSender
from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
@@ -74,6 +71,88 @@ logger = logging.getLogger(__name__)
CRLF = b"\r\n"
+class SHA256TransparentIOWriter:
+ """Will generate a SHA256 hash from a source stream transparently.
+
+ Args:
+ source: Source stream.
+ """
+
+ def __init__(self, source: BinaryIO):
+ self._hash = hashlib.sha256()
+ self._source = source
+
+ def write(self, buffer: Union[bytes, bytearray]) -> int:
+ """Wrapper for source.write()
+
+ Args:
+ buffer
+
+ Returns:
+ the value of source.write()
+ """
+ res = self._source.write(buffer)
+ self._hash.update(buffer)
+ return res
+
+ def hexdigest(self) -> str:
+ """The digest of the written or read value.
+
+ Returns:
+ The digest in hex formaat.
+ """
+ return self._hash.hexdigest()
+
+ def wrap(self) -> BinaryIO:
+ # This class implements a subset the IO interface and passes through everything else via __getattr__
+ return cast(BinaryIO, self)
+
+ # Passthrough any other calls
+ def __getattr__(self, attr_name: str) -> Any:
+ return getattr(self._source, attr_name)
+
+
+class SHA256TransparentIOReader:
+ """Will generate a SHA256 hash from a source stream transparently.
+
+ Args:
+ source: Source IO stream.
+ """
+
+ def __init__(self, source: IO):
+ self._hash = hashlib.sha256()
+ self._source = source
+
+ def read(self, n: int = -1) -> bytes:
+ """Wrapper for source.read()
+
+ Args:
+ n
+
+ Returns:
+ the value of source.read()
+ """
+ bytes = self._source.read(n)
+ self._hash.update(bytes)
+ return bytes
+
+ def hexdigest(self) -> str:
+ """The digest of the written or read value.
+
+ Returns:
+ The digest in hex formaat.
+ """
+ return self._hash.hexdigest()
+
+ def wrap(self) -> IO:
+ # This class implements a subset the IO interface and passes through everything else via __getattr__
+ return cast(IO, self)
+
+ # Passthrough any other calls
+ def __getattr__(self, attr_name: str) -> Any:
+ return getattr(self._source, attr_name)
+
+
class MediaStorage:
"""Responsible for storing/fetching files from local sources.
@@ -111,7 +190,6 @@ class MediaStorage:
Returns:
the file path written to in the primary media store
"""
-
async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
@@ -213,7 +291,7 @@ class MediaStorage:
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
logger.debug("responding with local file %s", local_path)
- return FileResponder(open(local_path, "rb"))
+ return FileResponder(self.hs, open(local_path, "rb"))
logger.debug("local file %s did not exist", local_path)
for provider in self.storage_providers:
@@ -336,13 +414,12 @@ class FileResponder(Responder):
is closed when finished streaming.
"""
- def __init__(self, open_file: IO):
+ def __init__(self, hs: "HomeServer", open_file: BinaryIO):
+ self.hs = hs
self.open_file = open_file
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
- return make_deferred_yieldable(
- FileSender().beginFileTransfer(self.open_file, consumer)
- )
+ return ThreadedFileSender(self.hs).beginFileTransfer(self.open_file, consumer)
def __exit__(
self,
@@ -549,7 +626,7 @@ class MultipartFileConsumer:
Calculate the content length of the multipart response
in bytes.
"""
- if not self.length:
+ if self.length is None:
return None
# calculate length of json field and content-type, disposition headers
json_field = json.dumps(self.json_field)
|