diff --git a/changelog.d/17239.misc b/changelog.d/17239.misc
new file mode 100644
index 0000000000..9fca36bb29
--- /dev/null
+++ b/changelog.d/17239.misc
@@ -0,0 +1 @@
+Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 0e875132f6..9da8495950 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -650,7 +650,7 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers = await self.client.download_media(
server_name,
@@ -693,8 +693,6 @@ class MediaRepository:
)
raise SynapseError(502, "Failed to fetch remote media")
- await finish()
-
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
@@ -1045,14 +1043,9 @@ class MediaRepository:
),
)
- with self.media_storage.store_into_file(file_info) as (
- f,
- fname,
- finish,
- ):
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
await self.media_storage.write_to_file(t_byte_source, f)
- await finish()
finally:
t_byte_source.close()
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b45b319f5c..9979c48eac 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -27,10 +27,9 @@ from typing import (
IO,
TYPE_CHECKING,
Any,
- Awaitable,
+ AsyncIterator,
BinaryIO,
Callable,
- Generator,
Optional,
Sequence,
Tuple,
@@ -97,11 +96,9 @@ class MediaStorage:
the file path written to in the primary media store
"""
- with self.store_into_file(file_info) as (f, fname, finish_cb):
+ async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
- # Write to the other storage providers
- await finish_cb()
return fname
@@ -111,32 +108,27 @@ class MediaStorage:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@trace_with_opname("MediaStorage.store_into_file")
- @contextlib.contextmanager
- def store_into_file(
+ @contextlib.asynccontextmanager
+ async def store_into_file(
self, file_info: FileInfo
- ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
- """Context manager used to get a file like object to write into, as
+ ) -> AsyncIterator[Tuple[BinaryIO, str]]:
+ """Async Context manager used to get a file like object to write into, as
described by file_info.
- Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
- like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns an awaitable.
+ Actually yields a 2-tuple (file, fname,), where file is a file
+ like object that can be written to and fname is the absolute path of file
+ on disk.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
- finish_cb must be called and waited on after the file has been successfully been
- written to. Should not be called if there was an error. Checks for spam and
- stores the file into the configured storage providers.
-
Args:
file_info: Info about the file to store
Example:
- with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ async with media_storage.store_into_file(info) as (f, fname,):
# .. write into f ...
- await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -145,62 +137,42 @@ class MediaStorage:
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
- finished_called = [False]
-
main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo"
)
main_media_repo_write_trace_scope.__enter__()
- try:
- with open(fname, "wb") as f:
-
- async def finish() -> None:
- # When someone calls finish, we assume they are done writing to the main media repo
- main_media_repo_write_trace_scope.__exit__(None, None, None)
-
- with start_active_span("writing to other storage providers"):
- # Ensure that all writes have been flushed and close the
- # file.
- f.flush()
- f.close()
-
- spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
- ReadableFileWrapper(self.clock, fname), file_info
- )
- if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
- logger.info("Blocking media due to spam checker")
- # Note that we'll delete the stored media, due to the
- # try/except below. The media also won't be stored in
- # the DB.
- # We currently ignore any additional field returned by
- # the spam-check API.
- raise SpamMediaException(errcode=spam_check[0])
-
- for provider in self.storage_providers:
- with start_active_span(str(provider)):
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
- yield f, fname, finish
- except Exception as e:
+ with main_media_repo_write_trace_scope:
try:
- main_media_repo_write_trace_scope.__exit__(
- type(e), None, e.__traceback__
- )
- os.remove(fname)
- except Exception:
- pass
+ with open(fname, "wb") as f:
+ yield f, fname
- raise e from None
+ except Exception as e:
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
- if not finished_called:
- exc = Exception("Finished callback not called")
- main_media_repo_write_trace_scope.__exit__(
- type(exc), None, exc.__traceback__
+ raise e from None
+
+ with start_active_span("writing to other storage providers"):
+ spam_check = (
+ await self._spam_checker_module_callbacks.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
)
- raise exc
+ if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ # We currently ignore any additional field returned by
+ # the spam-check API.
+ raise SpamMediaException(errcode=spam_check[0])
+
+ for provider in self.storage_providers:
+ with start_active_span(str(provider)):
+ await provider.store_file(path, file_info)
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 3897823b35..2e65a04789 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -592,7 +592,7 @@ class UrlPreviewer:
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"):
if not allow_data_urls:
raise SynapseError(
@@ -603,8 +603,6 @@ class UrlPreviewer:
else:
download_result = await self._download_url(url, f)
- await finish()
-
try:
time_now_ms = self.clock.time_msec()
diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py
index 600cbf8963..be4a289ec1 100644
--- a/tests/rest/client/test_media.py
+++ b/tests/rest/client/test_media.py
@@ -93,13 +93,13 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
- with hs.get_media_repository().media_storage.store_into_file(file_info) as (
- f,
- fname,
- finish,
- ):
- f.write(SMALL_PNG)
- self.get_success(finish())
+
+ media_storage = hs.get_media_repository().media_storage
+
+ ctx = media_storage.store_into_file(file_info)
+ (f, fname) = self.get_success(ctx.__aenter__())
+ f.write(SMALL_PNG)
+ self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(
diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py
index 88988f3a22..72205c6bb3 100644
--- a/tests/rest/media/test_domain_blocking.py
+++ b/tests/rest/media/test_domain_blocking.py
@@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404.
file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
- with hs.get_media_repository().media_storage.store_into_file(file_info) as (
- f,
- fname,
- finish,
- ):
- f.write(SMALL_PNG)
- self.get_success(finish())
+
+ media_storage = hs.get_media_repository().media_storage
+
+ ctx = media_storage.store_into_file(file_info)
+ (f, fname) = self.get_success(ctx.__aenter__())
+ f.write(SMALL_PNG)
+ self.get_success(ctx.__aexit__(None, None, None))
self.get_success(
self.store.store_cached_remote_media(
|