summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-05-29 11:14:42 +0100
committerGitHub <noreply@github.com>2024-05-29 11:14:42 +0100
commitbb5a692946e69c7f3686f1cb3fc0833b736f066a (patch)
tree6fccac75d6c8b4205d5eb1129a2e17003d60aeaf /synapse
parentMerge branch 'master' into develop (diff)
downloadsynapse-bb5a692946e69c7f3686f1cb3fc0833b736f066a.tar.xz
Fix slipped logging context when media rejected (#17239)
When a module rejects a piece of media we end up trying to close the
same logging context twice.

Instead of fixing the existing code we refactor to use an async context
manager, which is easier to write correctly.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/media/media_repository.py11
-rw-r--r--synapse/media/media_storage.py102
-rw-r--r--synapse/media/url_previewer.py4
3 files changed, 40 insertions, 77 deletions
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 0e875132f6..9da8495950 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -650,7 +650,7 @@ class MediaRepository:
 
         file_info = FileInfo(server_name=server_name, file_id=file_id)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
             try:
                 length, headers = await self.client.download_media(
                     server_name,
@@ -693,8 +693,6 @@ class MediaRepository:
                 )
                 raise SynapseError(502, "Failed to fetch remote media")
 
-            await finish()
-
             if b"Content-Type" in headers:
                 media_type = headers[b"Content-Type"][0].decode("ascii")
             else:
@@ -1045,14 +1043,9 @@ class MediaRepository:
                     ),
                 )
 
-                with self.media_storage.store_into_file(file_info) as (
-                    f,
-                    fname,
-                    finish,
-                ):
+                async with self.media_storage.store_into_file(file_info) as (f, fname):
                     try:
                         await self.media_storage.write_to_file(t_byte_source, f)
-                        await finish()
                     finally:
                         t_byte_source.close()
 
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b45b319f5c..9979c48eac 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -27,10 +27,9 @@ from typing import (
     IO,
     TYPE_CHECKING,
     Any,
-    Awaitable,
+    AsyncIterator,
     BinaryIO,
     Callable,
-    Generator,
     Optional,
     Sequence,
     Tuple,
@@ -97,11 +96,9 @@ class MediaStorage:
             the file path written to in the primary media store
         """
 
-        with self.store_into_file(file_info) as (f, fname, finish_cb):
+        async with self.store_into_file(file_info) as (f, fname):
             # Write to the main media repository
             await self.write_to_file(source, f)
-            # Write to the other storage providers
-            await finish_cb()
 
         return fname
 
@@ -111,32 +108,27 @@ class MediaStorage:
         await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
 
     @trace_with_opname("MediaStorage.store_into_file")
-    @contextlib.contextmanager
-    def store_into_file(
+    @contextlib.asynccontextmanager
+    async def store_into_file(
         self, file_info: FileInfo
-    ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
-        """Context manager used to get a file like object to write into, as
+    ) -> AsyncIterator[Tuple[BinaryIO, str]]:
+        """Async Context manager used to get a file like object to write into, as
         described by file_info.
 
-        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
-        like object that can be written to, fname is the absolute path of file
-        on disk, and finish_cb is a function that returns an awaitable.
+        Actually yields a 2-tuple (file, fname,), where file is a file
+        like object that can be written to and fname is the absolute path of file
+        on disk.
 
         fname can be used to read the contents from after upload, e.g. to
         generate thumbnails.
 
-        finish_cb must be called and waited on after the file has been successfully been
-        written to. Should not be called if there was an error. Checks for spam and
-        stores the file into the configured storage providers.
-
         Args:
             file_info: Info about the file to store
 
         Example:
 
-            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+            async with media_storage.store_into_file(info) as (f, fname,):
                 # .. write into f ...
-                await finish_cb()
         """
 
         path = self._file_info_to_path(file_info)
@@ -145,62 +137,42 @@ class MediaStorage:
         dirname = os.path.dirname(fname)
         os.makedirs(dirname, exist_ok=True)
 
-        finished_called = [False]
-
         main_media_repo_write_trace_scope = start_active_span(
             "writing to main media repo"
         )
         main_media_repo_write_trace_scope.__enter__()
 
-        try:
-            with open(fname, "wb") as f:
-
-                async def finish() -> None:
-                    # When someone calls finish, we assume they are done writing to the main media repo
-                    main_media_repo_write_trace_scope.__exit__(None, None, None)
-
-                    with start_active_span("writing to other storage providers"):
-                        # Ensure that all writes have been flushed and close the
-                        # file.
-                        f.flush()
-                        f.close()
-
-                        spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
-                            ReadableFileWrapper(self.clock, fname), file_info
-                        )
-                        if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
-                            logger.info("Blocking media due to spam checker")
-                            # Note that we'll delete the stored media, due to the
-                            # try/except below. The media also won't be stored in
-                            # the DB.
-                            # We currently ignore any additional field returned by
-                            # the spam-check API.
-                            raise SpamMediaException(errcode=spam_check[0])
-
-                        for provider in self.storage_providers:
-                            with start_active_span(str(provider)):
-                                await provider.store_file(path, file_info)
-
-                        finished_called[0] = True
-
-                yield f, fname, finish
-        except Exception as e:
+        with main_media_repo_write_trace_scope:
             try:
-                main_media_repo_write_trace_scope.__exit__(
-                    type(e), None, e.__traceback__
-                )
-                os.remove(fname)
-            except Exception:
-                pass
+                with open(fname, "wb") as f:
+                    yield f, fname
 
-            raise e from None
+            except Exception as e:
+                try:
+                    os.remove(fname)
+                except Exception:
+                    pass
 
-        if not finished_called:
-            exc = Exception("Finished callback not called")
-            main_media_repo_write_trace_scope.__exit__(
-                type(exc), None, exc.__traceback__
+                raise e from None
+
+        with start_active_span("writing to other storage providers"):
+            spam_check = (
+                await self._spam_checker_module_callbacks.check_media_file_for_spam(
+                    ReadableFileWrapper(self.clock, fname), file_info
+                )
             )
-            raise exc
+            if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
+                logger.info("Blocking media due to spam checker")
+                # Note that we'll delete the stored media, due to the
+                # try/except below. The media also won't be stored in
+                # the DB.
+                # We currently ignore any additional field returned by
+                # the spam-check API.
+                raise SpamMediaException(errcode=spam_check[0])
+
+            for provider in self.storage_providers:
+                with start_active_span(str(provider)):
+                    await provider.store_file(path, file_info)
 
     async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
         """Attempts to fetch media described by file_info from the local cache
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 3897823b35..2e65a04789 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -592,7 +592,7 @@ class UrlPreviewer:
 
         file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
 
-        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+        async with self.media_storage.store_into_file(file_info) as (f, fname):
             if url.startswith("data:"):
                 if not allow_data_urls:
                     raise SynapseError(
@@ -603,8 +603,6 @@ class UrlPreviewer:
             else:
                 download_result = await self._download_url(url, f)
 
-            await finish()
-
         try:
             time_now_ms = self.clock.time_msec()