summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-07-24 09:39:02 -0400
committerGitHub <noreply@github.com>2020-07-24 09:39:02 -0400
commit5ea29d7f850b6d2acbbfaf2e81bc5f0625411320 (patch)
tree02322e5d5fefaab399b4277d77009f119bae8d90
parentReturn an empty body for OPTIONS requests. (#7886) (diff)
downloadsynapse-5ea29d7f850b6d2acbbfaf2e81bc5f0625411320.tar.xz
Convert more of the media code to async/await (#7873)
-rw-r--r--changelog.d/7873.misc1
-rw-r--r--synapse/rest/media/v1/_base.py15
-rw-r--r--synapse/rest/media/v1/media_storage.py60
-rw-r--r--tests/rest/media/v1/test_media_storage.py5
4 files changed, 47 insertions, 34 deletions
diff --git a/changelog.d/7873.misc b/changelog.d/7873.misc
new file mode 100644
index 0000000000..58260764e7
--- /dev/null
+++ b/changelog.d/7873.misc
@@ -0,0 +1 @@
+Convert more media code to async/await.
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 595849f9d5..9a847130c0 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -18,7 +18,6 @@ import logging
 import os
 import urllib
 
-from twisted.internet import defer
 from twisted.protocols.basic import FileSender
 
 from synapse.api.errors import Codes, SynapseError, cs_error
@@ -77,8 +76,9 @@ def respond_404(request):
     )
 
 
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+    request, media_type, file_path, file_size=None, upload_name=None
+):
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
@@ -89,7 +89,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
         add_file_headers(request, media_type, file_size, upload_name)
 
         with open(file_path, "rb") as f:
-            yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+            await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
 
         finish_request(request)
     else:
@@ -198,8 +198,9 @@ def _can_encode_filename_as_token(x):
     return True
 
 
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+    request, responder, media_type, file_size, upload_name=None
+):
     """Responds to the request with given responder. If responder is None then
     returns 404.
 
@@ -218,7 +219,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
     add_file_headers(request, media_type, file_size, upload_name)
     try:
         with responder:
-            yield responder.write_to_consumer(request)
+            await responder.write_to_consumer(request)
     except Exception as e:
         # The majority of the time this will be due to the client having gone
         # away. Unfortunately, Twisted simply throws a generic exception at us
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 79cb0dddbe..66bc1c3360 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -14,17 +14,18 @@
 # limitations under the License.
 
 import contextlib
+import inspect
 import logging
 import os
 import shutil
+from typing import Optional
 
-from twisted.internet import defer
 from twisted.protocols.basic import FileSender
 
 from synapse.logging.context import defer_to_thread, make_deferred_yieldable
 from synapse.util.file_consumer import BackgroundFileConsumer
 
-from ._base import Responder
+from ._base import FileInfo, Responder
 
 logger = logging.getLogger(__name__)
 
@@ -46,25 +47,24 @@ class MediaStorage(object):
         self.filepaths = filepaths
         self.storage_providers = storage_providers
 
-    @defer.inlineCallbacks
-    def store_file(self, source, file_info):
+    async def store_file(self, source, file_info: FileInfo) -> str:
         """Write `source` to the on disk media store, and also any other
         configured storage providers
 
         Args:
             source: A file like object that should be written
-            file_info (FileInfo): Info about the file to store
+            file_info: Info about the file to store
 
         Returns:
-            Deferred[str]: the file path written to in the primary media store
+            the file path written to in the primary media store
         """
 
         with self.store_into_file(file_info) as (f, fname, finish_cb):
             # Write to the main repository
-            yield defer_to_thread(
+            await defer_to_thread(
                 self.hs.get_reactor(), _write_file_synchronously, source, f
             )
-            yield finish_cb()
+            await finish_cb()
 
         return fname
 
@@ -75,7 +75,7 @@ class MediaStorage(object):
 
         Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
         like object that can be written to, fname is the absolute path of file
-        on disk, and finish_cb is a function that returns a Deferred.
+        on disk, and finish_cb is a function that returns an awaitable.
 
         fname can be used to read the contents from after upload, e.g. to
         generate thumbnails.
@@ -91,7 +91,7 @@ class MediaStorage(object):
 
             with media_storage.store_into_file(info) as (f, fname, finish_cb):
                 # .. write into f ...
-                yield finish_cb()
+                await finish_cb()
         """
 
         path = self._file_info_to_path(file_info)
@@ -103,10 +103,13 @@ class MediaStorage(object):
 
         finished_called = [False]
 
-        @defer.inlineCallbacks
-        def finish():
+        async def finish():
             for provider in self.storage_providers:
-                yield provider.store_file(path, file_info)
+                # store_file is supposed to return an Awaitable, but guard
+                # against improper implementations.
+                result = provider.store_file(path, file_info)
+                if inspect.isawaitable(result):
+                    await result
 
             finished_called[0] = True
 
@@ -123,17 +126,15 @@ class MediaStorage(object):
         if not finished_called:
             raise Exception("Finished callback not called")
 
-    @defer.inlineCallbacks
-    def fetch_media(self, file_info):
+    async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
         and configured storage providers.
 
         Args:
-            file_info (FileInfo)
+            file_info
 
         Returns:
-            Deferred[Responder|None]: Returns a Responder if the file was found,
-                otherwise None.
+            Returns a Responder if the file was found, otherwise None.
         """
 
         path = self._file_info_to_path(file_info)
@@ -142,23 +143,26 @@ class MediaStorage(object):
             return FileResponder(open(local_path, "rb"))
 
         for provider in self.storage_providers:
-            res = yield provider.fetch(path, file_info)
+            res = provider.fetch(path, file_info)
+            # Fetch is supposed to return an Awaitable, but guard against
+            # improper implementations.
+            if inspect.isawaitable(res):
+                res = await res
             if res:
                 logger.debug("Streaming %s from %s", path, provider)
                 return res
 
         return None
 
-    @defer.inlineCallbacks
-    def ensure_media_is_in_local_cache(self, file_info):
+    async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
         """Ensures that the given file is in the local cache. Attempts to
         download it from storage providers if it isn't.
 
         Args:
-            file_info (FileInfo)
+            file_info
 
         Returns:
-            Deferred[str]: Full path to local file
+            Full path to local file
         """
         path = self._file_info_to_path(file_info)
         local_path = os.path.join(self.local_media_directory, path)
@@ -170,14 +174,18 @@ class MediaStorage(object):
             os.makedirs(dirname)
 
         for provider in self.storage_providers:
-            res = yield provider.fetch(path, file_info)
+            res = provider.fetch(path, file_info)
+            # Fetch is supposed to return an Awaitable, but guard against
+            # improper implementations.
+            if inspect.isawaitable(res):
+                res = await res
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
                         open(local_path, "wb"), self.hs.get_reactor()
                     )
-                    yield res.write_to_consumer(consumer)
-                    yield consumer.wait()
+                    await res.write_to_consumer(consumer)
+                    await consumer.wait()
                 return local_path
 
         raise Exception("file could not be found")
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 66fa5978b2..f4f3e56777 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -26,6 +26,7 @@ import attr
 from parameterized import parameterized_class
 from PIL import Image as Image
 
+from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import make_deferred_yieldable
@@ -77,7 +78,9 @@ class MediaStorageTests(unittest.HomeserverTestCase):
 
         # This uses a real blocking threadpool so we have to wait for it to be
         # actually done :/
-        x = self.media_storage.ensure_media_is_in_local_cache(file_info)
+        x = defer.ensureDeferred(
+            self.media_storage.ensure_media_is_in_local_cache(file_info)
+        )
 
         # Hotloop until the threadpool does its job...
         self.wait_on_thread(x)