summary refs log tree commit diff
path: root/synapse/rest/media
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-04 09:44:25 -0400
committerGitHub <noreply@github.com>2020-08-04 09:44:25 -0400
commit8ff2deda720b46995868127f85f82fc6ba852d82 (patch)
treee11b3e6e34f86ca7ce5a342dd58f19542c8f21b1 /synapse/rest/media
parentConvert the SimpleHttpClient to async. (#8016) (diff)
downloadsynapse-8ff2deda720b46995868127f85f82fc6ba852d82.tar.xz
Fix async/await calls for broken media providers. (#8027)
Diffstat (limited to 'synapse/rest/media')
-rw-r--r--synapse/rest/media/v1/media_storage.py23
-rw-r--r--synapse/rest/media/v1/storage_provider.py19
2 files changed, 20 insertions, 22 deletions
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 858b6d3005..ab1fa705bf 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-import inspect
 import logging
 import os
 import shutil
@@ -30,7 +29,7 @@ from .filepath import MediaFilePaths
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
-    from .storage_provider import StorageProvider
+    from .storage_provider import StorageProviderWrapper
 
 logger = logging.getLogger(__name__)
 
@@ -50,7 +49,7 @@ class MediaStorage(object):
         hs: "HomeServer",
         local_media_directory: str,
         filepaths: MediaFilePaths,
-        storage_providers: Sequence["StorageProvider"],
+        storage_providers: Sequence["StorageProviderWrapper"],
     ):
         self.hs = hs
         self.local_media_directory = local_media_directory
@@ -115,11 +114,7 @@ class MediaStorage(object):
 
         async def finish():
             for provider in self.storage_providers:
-                # store_file is supposed to return an Awaitable, but guard
-                # against improper implementations.
-                result = provider.store_file(path, file_info)
-                if inspect.isawaitable(result):
-                    await result
+                await provider.store_file(path, file_info)
 
             finished_called[0] = True
 
@@ -153,11 +148,7 @@ class MediaStorage(object):
             return FileResponder(open(local_path, "rb"))
 
         for provider in self.storage_providers:
-            res = provider.fetch(path, file_info)  # type: Any
-            # Fetch is supposed to return an Awaitable[Responder], but guard
-            # against improper implementations.
-            if inspect.isawaitable(res):
-                res = await res
+            res = await provider.fetch(path, file_info)  # type: Any
             if res:
                 logger.debug("Streaming %s from %s", path, provider)
                 return res
@@ -184,11 +175,7 @@ class MediaStorage(object):
             os.makedirs(dirname)
 
         for provider in self.storage_providers:
-            res = provider.fetch(path, file_info)  # type: Any
-            # Fetch is supposed to return an Awaitable[Responder], but guard
-            # against improper implementations.
-            if inspect.isawaitable(res):
-                res = await res
+            res = await provider.fetch(path, file_info)  # type: Any
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index a33f56e806..18c9ed48d6 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 import os
 import shutil
@@ -88,12 +89,18 @@ class StorageProviderWrapper(StorageProvider):
             return None
 
         if self.store_synchronous:
-            return await self.backend.store_file(path, file_info)
+            # store_file is supposed to return an Awaitable, but guard
+            # against improper implementations.
+            result = self.backend.store_file(path, file_info)
+            if inspect.isawaitable(result):
+                return await result
         else:
             # TODO: Handle errors.
-            def store():
+            async def store():
                 try:
-                    return self.backend.store_file(path, file_info)
+                    result = self.backend.store_file(path, file_info)
+                    if inspect.isawaitable(result):
+                        return await result
                 except Exception:
                     logger.exception("Error storing file")
 
@@ -101,7 +108,11 @@ class StorageProviderWrapper(StorageProvider):
             return None
 
     async def fetch(self, path, file_info):
-        return await self.backend.fetch(path, file_info)
+        # store_file is supposed to return an Awaitable, but guard
+        # against improper implementations.
+        result = self.backend.fetch(path, file_info)
+        if inspect.isawaitable(result):
+            return await result
 
 
 class FileStorageProviderBackend(StorageProvider):