diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 1b268ce4d4..21f334339b 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -22,12 +22,14 @@
import logging
import os
+import threading
import urllib
from abc import ABC, abstractmethod
from types import TracebackType
from typing import (
TYPE_CHECKING,
Awaitable,
+ BinaryIO,
Dict,
Generator,
List,
@@ -37,15 +39,19 @@ from typing import (
)
import attr
+from zope.interface import implementer
+from twisted.internet import interfaces
+from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
-from twisted.protocols.basic import FileSender
+from twisted.python.failure import Failure
from twisted.web.server import Request
from synapse.api.errors import Codes, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.types import ISynapseReactor
from synapse.util import Clock
from synapse.util.stringutils import is_ascii
@@ -138,7 +144,7 @@ async def respond_with_file(
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+ await ThreadedFileSender(request.reactor).beginFileTransfer(f, request)
finish_request(request)
else:
@@ -601,3 +607,168 @@ def _parseparam(s: bytes) -> Generator[bytes, None, None]:
f = s[:end]
yield f.strip()
s = s[end:]
+
+
+@implementer(interfaces.IPushProducer)
+class ThreadedFileSender:
+    """
+    A producer that sends the contents of a file to a consumer, reading from the
+    file on a thread.
+
+    This works by spawning a loop in a threadpool that repeatedly reads from the
+    file and sends it to the consumer. The main thread communicates with the
+    loop via two `threading.Event`s, which control when to start/pause reading
+    and when to terminate.
+
+    The read loop never touches the consumer or `self.deferred` directly; it
+    hands that work back to the reactor via `reactor.callFromThread`.
+    """
+
+    # How much data to read in one go.
+    CHUNK_SIZE = 2**14
+
+    # How long we wait for the consumer to be ready again before aborting the
+    # read.
+    TIMEOUT_SECONDS = 90.0
+
+    def __init__(self, reactor: ISynapseReactor) -> None:
+        self.reactor = reactor
+
+        # The file being sent, and the consumer receiving it. Both are set by
+        # `beginFileTransfer` and reset to None once we are done with them.
+        self.file: Optional[BinaryIO] = None
+        self.consumer: Optional[interfaces.IConsumer] = None
+
+        # Resolved with None once the read loop has finished, or errbacked if
+        # reading failed or the consumer asked us to stop.
+        self.deferred: "Deferred[None]" = Deferred()
+
+        # Signals if the thread should keep reading/sending data. Set means
+        # continue, clear means pause.
+        self.wakeup_event = threading.Event()
+
+        # Signals if the thread should terminate, e.g. because the consumer has
+        # gone away. Both this and `wakeup_event` should be set to terminate the
+        # loop (otherwise the thread will block on `wakeup_event`).
+        self.stop_event = threading.Event()
+
+    def beginFileTransfer(
+        self, file: BinaryIO, consumer: interfaces.IConsumer
+    ) -> "Deferred[None]":
+        """
+        Begin transferring a file.
+
+        Args:
+            file: The file to read from. It is closed once the read loop ends.
+            consumer: The consumer to write the file contents to.
+
+        Returns:
+            A deferred that resolves once the read loop has finished, or fails
+            if reading errored or the consumer asked us to stop producing.
+        """
+        self.file = file
+        self.consumer = consumer
+
+        # Register as a streaming (push) producer.
+        self.consumer.registerProducer(self, True)
+
+        # We set the wakeup signal as we should start producing immediately.
+        self.wakeup_event.set()
+        defer_to_thread(self.reactor, self._on_thread_read_loop)
+
+        return make_deferred_yieldable(self.deferred)
+
+    def resumeProducing(self) -> None:
+        """interfaces.IPushProducer: let the read loop continue."""
+        self.wakeup_event.set()
+
+    def pauseProducing(self) -> None:
+        """interfaces.IPushProducer: make the read loop wait before reading more."""
+        self.wakeup_event.clear()
+
+    def stopProducing(self) -> None:
+        """interfaces.IPushProducer: the consumer has gone away, so abort."""
+
+        # Drop the consumer first so that chunks the read loop has already
+        # queued via `callFromThread` are discarded by `_write` rather than
+        # written to a consumer that told us to stop.
+        if self.consumer:
+            self.consumer.unregisterProducer()
+            self.consumer = None
+
+        # Terminate the thread loop. `stop_event` is set before waking the
+        # thread so that a loop blocked on `wakeup_event` always sees the stop
+        # request when it wakes, instead of reading another chunk.
+        self.stop_event.set()
+        self.wakeup_event.set()
+
+        if not self.deferred.called:
+            self.deferred.errback(Exception("Consumer asked us to stop producing"))
+
+    def _on_thread_read_loop(self) -> None:
+        """This is the loop that happens on a thread.
+
+        Reads `CHUNK_SIZE` bytes at a time until EOF or until told to stop,
+        passing each chunk to `_write` on the reactor thread. `_finish` is
+        always scheduled on exit, preceded by `_error` if an exception occurred.
+        """
+
+        try:
+            while not self.stop_event.is_set():
+                # We wait for the producer to signal that the consumer wants
+                # more data (or we should abort). `wait` returns immediately
+                # if the event is already set, and False on timeout.
+                if not self.wakeup_event.wait(self.TIMEOUT_SECONDS):
+                    raise Exception("Timed out waiting to resume")
+
+                # Check if we were woken up so that we abort the download
+                if self.stop_event.is_set():
+                    return
+
+                # The file should always have been set before we get here.
+                assert self.file is not None
+
+                chunk = self.file.read(self.CHUNK_SIZE)
+                if not chunk:
+                    # EOF
+                    return
+
+                self.reactor.callFromThread(self._write, chunk)
+
+        except Exception:
+            # `Failure()` with no arguments captures the current exception.
+            self.reactor.callFromThread(self._error, Failure())
+        finally:
+            self.reactor.callFromThread(self._finish)
+
+    def _write(self, chunk: bytes) -> None:
+        """Called from the thread to write a chunk of data.
+
+        The chunk is dropped if the consumer has already been released.
+        """
+        if self.consumer:
+            self.consumer.write(chunk)
+
+    def _error(self, failure: Failure) -> None:
+        """Called from the thread when there was a fatal error"""
+        if self.consumer:
+            self.consumer.unregisterProducer()
+            self.consumer = None
+
+        if not self.deferred.called:
+            self.deferred.errback(failure)
+
+    def _finish(self) -> None:
+        """Called from the thread when it finishes (either on success or
+        failure)."""
+        if self.file:
+            self.file.close()
+            self.file = None
+
+        if self.consumer:
+            self.consumer.unregisterProducer()
+            self.consumer = None
+
+        # On the failure path the deferred has already been errbacked.
+        if not self.deferred.called:
+            self.deferred.callback(None)
|