summary refs log tree commit diff
path: root/synapse/rest/media/v1/media_storage.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/media/v1/media_storage.py')
-rw-r--r--synapse/rest/media/v1/media_storage.py52
1 files changed, 28 insertions, 24 deletions
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 66bc1c3360..858b6d3005 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,13 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import contextlib
 import inspect
 import logging
 import os
 import shutil
-from typing import Optional
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
 
 from twisted.protocols.basic import FileSender
 
@@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+    from .storage_provider import StorageProvider
 
 logger = logging.getLogger(__name__)
 
@@ -34,20 +39,25 @@ class MediaStorage(object):
     """Responsible for storing/fetching files from local sources.
 
     Args:
-        hs (synapse.server.Homeserver)
-        local_media_directory (str): Base path where we store media on disk
-        filepaths (MediaFilePaths)
-        storage_providers ([StorageProvider]): List of StorageProvider that are
-            used to fetch and store files.
+        hs
+        local_media_directory: Base path where we store media on disk
+        filepaths
+        storage_providers: List of StorageProvider that are used to fetch and store files.
     """
 
-    def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        local_media_directory: str,
+        filepaths: MediaFilePaths,
+        storage_providers: Sequence["StorageProvider"],
+    ):
         self.hs = hs
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
 
-    async def store_file(self, source, file_info: FileInfo) -> str:
+    async def store_file(self, source: IO, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
         configured storage providers
 
@@ -69,7 +79,7 @@ class MediaStorage(object):
         return fname
 
     @contextlib.contextmanager
-    def store_into_file(self, file_info):
+    def store_into_file(self, file_info: FileInfo):
         """Context manager used to get a file like object to write into, as
         described by file_info.
 
@@ -85,7 +95,7 @@ class MediaStorage(object):
         error.
 
         Args:
-            file_info (FileInfo): Info about the file to store
+            file_info: Info about the file to store
 
         Example:
 
@@ -143,9 +153,9 @@ class MediaStorage(object):
             return FileResponder(open(local_path, "rb"))
 
         for provider in self.storage_providers:
-            res = provider.fetch(path, file_info)
-            # Fetch is supposed to return an Awaitable, but guard against
-            # improper implementations.
+            res = provider.fetch(path, file_info)  # type: Any
+            # Fetch is supposed to return an Awaitable[Responder], but guard
+            # against improper implementations.
             if inspect.isawaitable(res):
                 res = await res
             if res:
@@ -174,9 +184,9 @@ class MediaStorage(object):
             os.makedirs(dirname)
 
         for provider in self.storage_providers:
-            res = provider.fetch(path, file_info)
-            # Fetch is supposed to return an Awaitable, but guard against
-            # improper implementations.
+            res = provider.fetch(path, file_info)  # type: Any
+            # Fetch is supposed to return an Awaitable[Responder], but guard
+            # against improper implementations.
             if inspect.isawaitable(res):
                 res = await res
             if res:
@@ -190,17 +200,11 @@ class MediaStorage(object):
 
         raise Exception("file could not be found")
 
-    def _file_info_to_path(self, file_info):
+    def _file_info_to_path(self, file_info: FileInfo) -> str:
         """Converts file_info into a relative path.
 
         The path is suitable for storing files under a directory, e.g. used to
         store files on local FS under the base media repository directory.
-
-        Args:
-            file_info (FileInfo)
-
-        Returns:
-            str
         """
         if file_info.url_cache:
             if file_info.thumbnail: