summary refs log tree commit diff
path: root/synapse/rest/media/v1/media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/media/v1/media_storage.py')
-rw-r--r--synapse/rest/media/v1/media_storage.py96
1 files changed, 54 insertions, 42 deletions
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py

index 79cb0dddbe..858b6d3005 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -12,19 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib +import inspect import logging import os import shutil +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence -from twisted.internet import defer from twisted.protocols.basic import FileSender from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer -from ._base import Responder +from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + + from .storage_provider import StorageProvider logger = logging.getLogger(__name__) @@ -33,49 +39,53 @@ class MediaStorage(object): """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - @defer.inlineCallbacks - def store_file(self, source, file_info): + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers Args: source: A file like object that should be written - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Returns: - Deferred[str]: the file path written to in the primary media store + the file path written to in the primary media store """ with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - yield defer_to_thread( + await defer_to_thread( self.hs.get_reactor(), _write_file_synchronously, source, f ) - yield finish_cb() + await finish_cb() return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. Actually yields a 3-tuple (file, fname, finish_cb), where file is a file like object that can be written to, fname is the absolute path of file - on disk, and finish_cb is a function that returns a Deferred. + on disk, and finish_cb is a function that returns an awaitable. fname can be used to read the contents from after upload, e.g. to generate thumbnails. @@ -85,13 +95,13 @@ class MediaStorage(object): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: with media_storage.store_into_file(info) as (f, fname, finish_cb): # .. write into f ... - yield finish_cb() + await finish_cb() """ path = self._file_info_to_path(file_info) @@ -103,10 +113,13 @@ class MediaStorage(object): finished_called = [False] - @defer.inlineCallbacks - def finish(): + async def finish(): for provider in self.storage_providers: - yield provider.store_file(path, file_info) + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = provider.store_file(path, file_info) + if inspect.isawaitable(result): + await result finished_called[0] = True @@ -123,17 +136,15 @@ class MediaStorage(object): if not finished_called: raise Exception("Finished callback not called") - @defer.inlineCallbacks - def fetch_media(self, file_info): + async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: """Attempts to fetch media described by file_info from the local cache and configured storage providers. Args: - file_info (FileInfo) + file_info Returns: - Deferred[Responder|None]: Returns a Responder if the file was found, - otherwise None. + Returns a Responder if the file was found, otherwise None. """ path = self._file_info_to_path(file_info) @@ -142,23 +153,26 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. + if inspect.isawaitable(res): + res = await res if res: logger.debug("Streaming %s from %s", path, provider) return res return None - @defer.inlineCallbacks - def ensure_media_is_in_local_cache(self, file_info): + async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: """Ensures that the given file is in the local cache. Attempts to download it from storage providers if it isn't. Args: - file_info (FileInfo) + file_info Returns: - Deferred[str]: Full path to local file + Full path to local file """ path = self._file_info_to_path(file_info) local_path = os.path.join(self.local_media_directory, path) @@ -170,29 +184,27 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = yield provider.fetch(path, file_info) + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. + if inspect.isawaitable(res): + res = await res if res: with res: consumer = BackgroundFileConsumer( open(local_path, "wb"), self.hs.get_reactor() ) - yield res.write_to_consumer(consumer) - yield consumer.wait() + await res.write_to_consumer(consumer) + await consumer.wait() return local_path raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: