summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/media/media_repository.py11
-rw-r--r--synapse/media/media_storage.py102
-rw-r--r--synapse/media/url_previewer.py4
3 files changed, 40 insertions, 77 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 0e875132f6..9da8495950 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -650,7 +650,7 @@ class MediaRepository:
 
         file_info = FileInfo(server_name=server_name, file_id=file_id)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
             try:
                 length, headers = await self.client.download_media(
                     server_name,
@@ -693,8 +693,6 @@ class MediaRepository:
                 )
                 raise SynapseError(502, "Failed to fetch remote media")
 
-            await finish()
-
             if b"Content-Type" in headers:
                 media_type = headers[b"Content-Type"][0].decode("ascii")
             else:
@@ -1045,14 +1043,9 @@ class MediaRepository:
                     ),
                 )
 
-                with self.media_storage.store_into_file(file_info) as (
-                    f,
-                    fname,
-                    finish,
-                ):
+                async with self.media_storage.store_into_file(file_info) as (f, fname):
                     try:
                         await self.media_storage.write_to_file(t_byte_source, f)
-                        await finish()
                     finally:
                         t_byte_source.close()
 
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b45b319f5c..9979c48eac 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -27,10 +27,9 @@ from typing import (
     IO,
     TYPE_CHECKING,
     Any,
-    Awaitable,
+    AsyncIterator,
     BinaryIO,
     Callable,
-    Generator,
     Optional,
     Sequence,
     Tuple,
@@ -97,11 +96,9 @@ class MediaStorage:
             the file path written to in the primary media store
         """
 
-        with self.store_into_file(file_info) as (f, fname, finish_cb):
+        async with self.store_into_file(file_info) as (f, fname):
             # Write to the main media repository
             await self.write_to_file(source, f)
-            # Write to the other storage providers
-            await finish_cb()
 
         return fname
 
@@ -111,32 +108,27 @@ class MediaStorage:
         await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
 
     @trace_with_opname("MediaStorage.store_into_file")
-    @contextlib.contextmanager
-    def store_into_file(
+    @contextlib.asynccontextmanager
+    async def store_into_file(
         self, file_info: FileInfo
-    ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
-        """Context manager used to get a file like object to write into, as
+    ) -> AsyncIterator[Tuple[BinaryIO, str]]:
+        """Async Context manager used to get a file like object to write into, as
         described by file_info.
 
-        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
-        like object that can be written to, fname is the absolute path of file
-        on disk, and finish_cb is a function that returns an awaitable.
+        Actually yields a 2-tuple (file, fname,), where file is a file
+        like object that can be written to and fname is the absolute path of file
+        on disk.
 
         fname can be used to read the contents from after upload, e.g. to
         generate thumbnails.
 
-        finish_cb must be called and waited on after the file has been successfully been
-        written to. Should not be called if there was an error. Checks for spam and
-        stores the file into the configured storage providers.
-
         Args:
             file_info: Info about the file to store
 
         Example:
 
-            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+            async with media_storage.store_into_file(info) as (f, fname,):
                 # .. write into f ...
-                await finish_cb()
         """
 
         path = self._file_info_to_path(file_info)
@@ -145,62 +137,42 @@ class MediaStorage:
         dirname = os.path.dirname(fname)
         os.makedirs(dirname, exist_ok=True)
 
-        finished_called = [False]
-
         main_media_repo_write_trace_scope = start_active_span(
             "writing to main media repo"
         )
         main_media_repo_write_trace_scope.__enter__()
 
-        try:
-            with open(fname, "wb") as f:
-
-                async def finish() -> None:
-                    # When someone calls finish, we assume they are done writing to the main media repo
-                    main_media_repo_write_trace_scope.__exit__(None, None, None)
-
-                    with start_active_span("writing to other storage providers"):
-                        # Ensure that all writes have been flushed and close the
-                        # file.
-                        f.flush()
-                        f.close()
-
-                        spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
-                            ReadableFileWrapper(self.clock, fname), file_info
-                        )
-                        if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
-                            logger.info("Blocking media due to spam checker")
-                            # Note that we'll delete the stored media, due to the
-                            # try/except below. The media also won't be stored in
-                            # the DB.
-                            # We currently ignore any additional field returned by
-                            # the spam-check API.
-                            raise SpamMediaException(errcode=spam_check[0])
-
-                        for provider in self.storage_providers:
-                            with start_active_span(str(provider)):
-                                await provider.store_file(path, file_info)
-
-                        finished_called[0] = True
-
-                yield f, fname, finish
-        except Exception as e:
+        with main_media_repo_write_trace_scope:
             try:
-                main_media_repo_write_trace_scope.__exit__(
-                    type(e), None, e.__traceback__
-                )
-                os.remove(fname)
-            except Exception:
-                pass
+                with open(fname, "wb") as f:
+                    yield f, fname
 
-            raise e from None
+            except Exception as e:
+                try:
+                    os.remove(fname)
+                except Exception:
+                    pass
 
-        if not finished_called:
-            exc = Exception("Finished callback not called")
-            main_media_repo_write_trace_scope.__exit__(
-                type(exc), None, exc.__traceback__
+                raise e from None
+
+        with start_active_span("writing to other storage providers"):
+            spam_check = (
+                await self._spam_checker_module_callbacks.check_media_file_for_spam(
+                    ReadableFileWrapper(self.clock, fname), file_info
+                )
             )
-            raise exc
+            if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
+                logger.info("Blocking media due to spam checker")
+                # Note that we'll delete the stored media, due to the
+                # try/except below. The media also won't be stored in
+                # the DB.
+                # We currently ignore any additional field returned by
+                # the spam-check API.
+                raise SpamMediaException(errcode=spam_check[0])
+
+            for provider in self.storage_providers:
+                with start_active_span(str(provider)):
+                    await provider.store_file(path, file_info)
 
     async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 3897823b35..2e65a04789 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -592,7 +592,7 @@ class UrlPreviewer:
 
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
             if url.startswith("data:"):
                 if not allow_data_urls:
                     raise SynapseError(
@@ -603,8 +603,6 @@ class UrlPreviewer:
             else:
                 download_result = await self._download_url(url, f)
 
-            await finish()
-
         try:
             time_now_ms = self.clock.time_msec()