summary refs log tree commit diff
path: root/synapse/media/media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/media/media_storage.py103
1 files changed, 90 insertions, 13 deletions
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py

index 2a106bb0eb..afd33c02a1 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py
@@ -19,6 +19,7 @@ # # import contextlib +import hashlib import json import logging import os @@ -49,15 +50,11 @@ from zope.interface import implementer from twisted.internet import interfaces from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer -from twisted.protocols.basic import FileSender from synapse.api.errors import NotFoundError -from synapse.logging.context import ( - defer_to_thread, - make_deferred_yieldable, - run_in_background, -) +from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.opentracing import start_active_span, trace, trace_with_opname +from synapse.media._base import ThreadedFileSender from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer @@ -74,6 +71,88 @@ logger = logging.getLogger(__name__) CRLF = b"\r\n" +class SHA256TransparentIOWriter: + """Will generate a SHA256 hash from a source stream transparently. + + Args: + source: Source stream. + """ + + def __init__(self, source: BinaryIO): + self._hash = hashlib.sha256() + self._source = source + + def write(self, buffer: Union[bytes, bytearray]) -> int: + """Wrapper for source.write() + + Args: + buffer + + Returns: + the value of source.write() + """ + res = self._source.write(buffer) + self._hash.update(buffer) + return res + + def hexdigest(self) -> str: + """The digest of the written or read value. + + Returns: + The digest in hex format. + """ + return self._hash.hexdigest() + + def wrap(self) -> BinaryIO: + # This class implements a subset of the IO interface and passes through everything else via __getattr__ + return cast(BinaryIO, self) + + # Passthrough any other calls + def __getattr__(self, attr_name: str) -> Any: + return getattr(self._source, attr_name) + + +class SHA256TransparentIOReader: + """Will generate a SHA256 hash from a source stream transparently. + + Args: + source: Source IO stream. 
+ """ + + def __init__(self, source: IO): + self._hash = hashlib.sha256() + self._source = source + + def read(self, n: int = -1) -> bytes: + """Wrapper for source.read() + + Args: + n + + Returns: + the value of source.read() + """ + bytes = self._source.read(n) + self._hash.update(bytes) + return bytes + + def hexdigest(self) -> str: + """The digest of the written or read value. + + Returns: + The digest in hex format. + """ + return self._hash.hexdigest() + + def wrap(self) -> IO: + # This class implements a subset of the IO interface and passes through everything else via __getattr__ + return cast(IO, self) + + # Passthrough any other calls + def __getattr__(self, attr_name: str) -> Any: + return getattr(self._source, attr_name) + + class MediaStorage: """Responsible for storing/fetching files from local sources. @@ -111,7 +190,6 @@ class MediaStorage: Returns: the file path written to in the primary media store """ - async with self.store_into_file(file_info) as (f, fname): # Write to the main media repository await self.write_to_file(source, f) @@ -213,7 +291,7 @@ class MediaStorage: local_path = os.path.join(self.local_media_directory, path) if os.path.exists(local_path): logger.debug("responding with local file %s", local_path) - return FileResponder(open(local_path, "rb")) + return FileResponder(self.hs, open(local_path, "rb")) logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: @@ -336,13 +414,12 @@ class FileResponder(Responder): is closed when finished streaming. 
""" - def __init__(self, open_file: IO): + def __init__(self, hs: "HomeServer", open_file: BinaryIO): + self.hs = hs self.open_file = open_file def write_to_consumer(self, consumer: IConsumer) -> Deferred: - return make_deferred_yieldable( - FileSender().beginFileTransfer(self.open_file, consumer) - ) + return ThreadedFileSender(self.hs).beginFileTransfer(self.open_file, consumer) def __exit__( self, @@ -549,7 +626,7 @@ class MultipartFileConsumer: Calculate the content length of the multipart response in bytes. """ - if not self.length: + if self.length is None: return None # calculate length of json field and content-type, disposition headers json_field = json.dumps(self.json_field)