summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-10-30 10:55:24 +0000
committerGitHub <noreply@github.com>2020-10-30 10:55:24 +0000
commit46f4be94b410776ef3f922af2f437eb17631d2fa (patch)
treec686c30dac2bac97c2b0c05dc2c7e224748de375 /synapse/rest
parentFix optional parameter in stripped state storage method (#8688) (diff)
downloadsynapse-46f4be94b410776ef3f922af2f437eb17631d2fa.tar.xz
Fix race for concurrent downloads of remote media. (#8682)
Fixes #6755
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/media/v1/media_repository.py165
-rw-r--r--synapse/rest/media/v1/media_storage.py30
2 files changed, 125 insertions, 70 deletions
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 5cce7237a0..9cac74ebd8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -305,15 +305,12 @@ class MediaRepository:
         # file_id is the ID we use to track the file locally. If we've already
         # seen the file then reuse the existing ID, otherwise genereate a new
         # one.
-        if media_info:
-            file_id = media_info["filesystem_id"]
-        else:
-            file_id = random_string(24)
-
-        file_info = FileInfo(server_name, file_id)
 
         # If we have an entry in the DB, try and look for it
         if media_info:
+            file_id = media_info["filesystem_id"]
+            file_info = FileInfo(server_name, file_id)
+
             if media_info["quarantined_by"]:
                 logger.info("Media is quarantined")
                 raise NotFoundError()
@@ -324,14 +321,34 @@ class MediaRepository:
 
         # Failed to find the file anywhere, lets download it.
 
-        media_info = await self._download_remote_file(server_name, media_id, file_id)
+        try:
+            media_info = await self._download_remote_file(server_name, media_id,)
+        except SynapseError:
+            raise
+        except Exception as e:
+            # An exception may be because we downloaded media in another
+            # process, so let's check if we magically have the media.
+            media_info = await self.store.get_cached_remote_media(server_name, media_id)
+            if not media_info:
+                raise e
+
+        file_id = media_info["filesystem_id"]
+        file_info = FileInfo(server_name, file_id)
+
+        # We generate thumbnails even if another process downloaded the media
+        # as a) it's conceivable that the other download request dies before it
+        # generates thumbnails, but mainly b) we want to be sure the thumbnails
+        # have finished being generated before responding to the client,
+        # otherwise they'll request thumbnails and get a 404 if they're not
+        # ready yet.
+        await self._generate_thumbnails(
+            server_name, media_id, file_id, media_info["media_type"]
+        )
 
         responder = await self.media_storage.fetch_media(file_info)
         return responder, media_info
 
-    async def _download_remote_file(
-        self, server_name: str, media_id: str, file_id: str
-    ) -> dict:
+    async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -346,6 +363,8 @@ class MediaRepository:
             The media info of the file.
         """
 
+        file_id = random_string(24)
+
         file_info = FileInfo(server_name=server_name, file_id=file_id)
 
         with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -401,22 +420,32 @@ class MediaRepository:
 
             await finish()
 
-        media_type = headers[b"Content-Type"][0].decode("ascii")
-        upload_name = get_filename_from_headers(headers)
-        time_now_ms = self.clock.time_msec()
+            media_type = headers[b"Content-Type"][0].decode("ascii")
+            upload_name = get_filename_from_headers(headers)
+            time_now_ms = self.clock.time_msec()
+
+            # Multiple remote media download requests can race (when using
+            # multiple media repos), so this may throw a violation constraint
+            # exception. If it does we'll delete the newly downloaded file from
+            # disk (as we're in the ctx manager).
+            #
+            # However: we've already called `finish()` so we may have also
+            # written to the storage providers. This is preferable to the
+            # alternative where we call `finish()` *after* this, where we could
+            # end up having an entry in the DB but fail to write the files to
+            # the storage providers.
+            await self.store.store_cached_remote_media(
+                origin=server_name,
+                media_id=media_id,
+                media_type=media_type,
+                time_now_ms=self.clock.time_msec(),
+                upload_name=upload_name,
+                media_length=length,
+                filesystem_id=file_id,
+            )
 
         logger.info("Stored remote media in file %r", fname)
 
-        await self.store.store_cached_remote_media(
-            origin=server_name,
-            media_id=media_id,
-            media_type=media_type,
-            time_now_ms=self.clock.time_msec(),
-            upload_name=upload_name,
-            media_length=length,
-            filesystem_id=file_id,
-        )
-
         media_info = {
             "media_type": media_type,
             "media_length": length,
@@ -425,8 +454,6 @@ class MediaRepository:
             "filesystem_id": file_id,
         }
 
-        await self._generate_thumbnails(server_name, media_id, file_id, media_type)
-
         return media_info
 
     def _get_thumbnail_requirements(self, media_type):
@@ -692,42 +719,60 @@ class MediaRepository:
             if not t_byte_source:
                 continue
 
-            try:
-                file_info = FileInfo(
-                    server_name=server_name,
-                    file_id=file_id,
-                    thumbnail=True,
-                    thumbnail_width=t_width,
-                    thumbnail_height=t_height,
-                    thumbnail_method=t_method,
-                    thumbnail_type=t_type,
-                    url_cache=url_cache,
-                )
-
-                output_path = await self.media_storage.store_file(
-                    t_byte_source, file_info
-                )
-            finally:
-                t_byte_source.close()
-
-            t_len = os.path.getsize(output_path)
+            file_info = FileInfo(
+                server_name=server_name,
+                file_id=file_id,
+                thumbnail=True,
+                thumbnail_width=t_width,
+                thumbnail_height=t_height,
+                thumbnail_method=t_method,
+                thumbnail_type=t_type,
+                url_cache=url_cache,
+            )
 
-            # Write to database
-            if server_name:
-                await self.store.store_remote_media_thumbnail(
-                    server_name,
-                    media_id,
-                    file_id,
-                    t_width,
-                    t_height,
-                    t_type,
-                    t_method,
-                    t_len,
-                )
-            else:
-                await self.store.store_local_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method, t_len
-                )
+            with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+                try:
+                    await self.media_storage.write_to_file(t_byte_source, f)
+                    await finish()
+                finally:
+                    t_byte_source.close()
+
+                t_len = os.path.getsize(fname)
+
+                # Write to database
+                if server_name:
+                    # Multiple remote media download requests can race (when
+                    # using multiple media repos), so this may throw a violation
+                    # constraint exception. If it does we'll delete the newly
+                    # generated thumbnail from disk (as we're in the ctx
+                    # manager).
+                    #
+                    # However: we've already called `finish()` so we may have
+                    # also written to the storage providers. This is preferable
+                    # to the alternative where we call `finish()` *after* this,
+                    # where we could end up having an entry in the DB but fail
+                    # to write the files to the storage providers.
+                    try:
+                        await self.store.store_remote_media_thumbnail(
+                            server_name,
+                            media_id,
+                            file_id,
+                            t_width,
+                            t_height,
+                            t_type,
+                            t_method,
+                            t_len,
+                        )
+                    except Exception as e:
+                        thumbnail_exists = await self.store.get_remote_media_thumbnail(
+                            server_name, media_id, t_width, t_height, t_type,
+                        )
+                        if not thumbnail_exists:
+                            raise e
+                else:
+                    await self.store.store_local_thumbnail(
+                        media_id, t_width, t_height, t_type, t_method, t_len
+                    )
 
         return {"width": m_width, "height": m_height}
 
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage:
         storage_providers: Sequence["StorageProviderWrapper"],
     ):
         self.hs = hs
+        self.reactor = hs.get_reactor()
         self.local_media_directory = local_media_directory
         self.filepaths = filepaths
         self.storage_providers = storage_providers
@@ -70,13 +71,16 @@ class MediaStorage:
 
         with self.store_into_file(file_info) as (f, fname, finish_cb):
             # Write to the main repository
-            await defer_to_thread(
-                self.hs.get_reactor(), _write_file_synchronously, source, f
-            )
+            await self.write_to_file(source, f)
             await finish_cb()
 
         return fname
 
+    async def write_to_file(self, source: IO, output: IO):
+        """Asynchronously write the `source` to `output`.
+        """
+        await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
     @contextlib.contextmanager
     def store_into_file(self, file_info: FileInfo):
         """Context manager used to get a file like object to write into, as
@@ -112,14 +116,20 @@ class MediaStorage:
 
         finished_called = [False]
 
-        async def finish():
-            for provider in self.storage_providers:
-                await provider.store_file(path, file_info)
-
-            finished_called[0] = True
-
         try:
             with open(fname, "wb") as f:
+
+                async def finish():
+                    # Ensure that all writes have been flushed and close the
+                    # file.
+                    f.flush()
+                    f.close()
+
+                    for provider in self.storage_providers:
+                        await provider.store_file(path, file_info)
+
+                    finished_called[0] = True
+
                 yield f, fname, finish
         except Exception:
             try:
@@ -210,7 +220,7 @@ class MediaStorage:
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
-                        open(local_path, "wb"), self.hs.get_reactor()
+                        open(local_path, "wb"), self.reactor
                     )
                     await res.write_to_consumer(consumer)
                     await consumer.wait()