diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 1b268ce4d4..2e48d2fdc7 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -28,6 +28,7 @@ from types import TracebackType
from typing import (
TYPE_CHECKING,
Awaitable,
+ BinaryIO,
Dict,
Generator,
List,
@@ -37,21 +38,28 @@ from typing import (
)
import attr
+from zope.interface import implementer
+from twisted.internet import interfaces
+from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
-from twisted.protocols.basic import FileSender
+from twisted.python.failure import Failure
from twisted.web.server import Request
from synapse.api.errors import Codes, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import (
+ defer_to_threadpool,
+ make_deferred_yieldable,
+ run_in_background,
+)
from synapse.util import Clock
+from synapse.util.async_helpers import DeferredEvent
from synapse.util.stringutils import is_ascii
if TYPE_CHECKING:
- from synapse.storage.databases.main.media_repository import LocalMedia
-
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -110,6 +118,9 @@ DEFAULT_MAX_TIMEOUT_MS = 20_000
# Maximum allowed timeout_ms for download and thumbnail requests
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS = 60_000
+# The ETag header value to use for immutable media. This can be anything.
+_IMMUTABLE_ETAG = "1"
+
def respond_404(request: SynapseRequest) -> None:
assert request.path is not None
@@ -122,6 +133,7 @@ def respond_404(request: SynapseRequest) -> None:
async def respond_with_file(
+ hs: "HomeServer",
request: SynapseRequest,
media_type: str,
file_path: str,
@@ -138,7 +150,7 @@ async def respond_with_file(
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+ await ThreadedFileSender(hs).beginFileTransfer(f, request)
finish_request(request)
else:
@@ -215,12 +227,7 @@ def add_file_headers(
request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
+ _add_cache_headers(request)
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))
@@ -231,6 +238,26 @@ def add_file_headers(
request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+def _add_cache_headers(request: Request) -> None:
+ """Adds the appropriate cache headers to the response"""
+
+ # Cache on the client for at least a day.
+ #
+ # We set this to "public,s-maxage=0,proxy-revalidate" to allow CDNs to cache
+ # the media, so long as they "revalidate" the media on every request. By
+ # revalidate, we mean send the request to Synapse with a `If-None-Match`
+ # header, to which Synapse can either respond with a 304 if the user is
+ # authenticated/authorized, or a 401/403 if they're not.
+ request.setHeader(
+ b"Cache-Control", b"public,max-age=86400,s-maxage=0,proxy-revalidate"
+ )
+
+ # Set an ETag header to allow requesters to use it in requests to check if
+ # the cache is still valid. Since media is immutable (though may be
+ # deleted), we just set this to a constant.
+ request.setHeader(b"ETag", _IMMUTABLE_ETAG)
+
+
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
_FILENAME_SEPARATOR_CHARS = {
@@ -279,7 +306,9 @@ async def respond_with_multipart_responder(
clock: Clock,
request: SynapseRequest,
responder: "Optional[Responder]",
- media_info: "LocalMedia",
+ media_type: str,
+ media_length: Optional[int],
+ upload_name: Optional[str],
) -> None:
"""
Responds to requests originating from the federation media `/download` endpoint by
@@ -303,7 +332,7 @@ async def respond_with_multipart_responder(
)
return
- if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
+ if media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
disposition = "inline"
else:
disposition = "attachment"
@@ -311,33 +340,35 @@ async def respond_with_multipart_responder(
def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
- if media_info.upload_name:
- if _can_encode_filename_as_token(media_info.upload_name):
+ if upload_name:
+ if _can_encode_filename_as_token(upload_name):
disposition = "%s; filename=%s" % (
disposition,
- media_info.upload_name,
+ upload_name,
)
else:
disposition = "%s; filename*=utf-8''%s" % (
disposition,
- _quote(media_info.upload_name),
+ _quote(upload_name),
)
from synapse.media.media_storage import MultipartFileConsumer
+ _add_cache_headers(request)
+
# note that currently the json_object is just {}, this will change when linked media
# is implemented
multipart_consumer = MultipartFileConsumer(
clock,
request,
- media_info.media_type,
- {},
+ media_type,
+ {}, # Note: if we change this we need to change the returned ETag.
disposition,
- media_info.media_length,
+ media_length,
)
logger.debug("Responding to media request with responder %s", responder)
- if media_info.media_length is not None:
+ if media_length is not None:
content_length = multipart_consumer.content_length()
assert content_length is not None
request.setHeader(b"Content-Length", b"%d" % (content_length,))
@@ -408,6 +439,46 @@ async def respond_with_responder(
finish_request(request)
+def respond_with_304(request: SynapseRequest) -> None:
+ request.setResponseCode(304)
+
+ # could alternatively use request.notifyFinish() and flip a flag when
+ # the Deferred fires, but since the flag is RIGHT THERE it seems like
+ # a waste.
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return None
+
+ _add_cache_headers(request)
+
+ request.finish()
+
+
+def check_for_cached_entry_and_respond(request: SynapseRequest) -> bool:
+ """Check if the request has a conditional header that allows us to return a
+ 304 Not Modified response, and if it does, return a 304 response.
+
+ This handles clients and intermediary proxies caching media.
+ This method assumes that the user has already been
+ authorised to request the media.
+
+ Returns True if we have responded."""
+
+ # We've checked the user has access to the media, so we now check if it
+ # is a "conditional request" and we can just return a `304 Not Modified`
+ # response. Since media is immutable (though may be deleted), we just
+ # check this is the expected constant.
+ etag = request.getHeader("If-None-Match")
+ if etag == _IMMUTABLE_ETAG:
+ # Return a `304 Not modified`.
+ respond_with_304(request)
+ return True
+
+ return False
+
+
class Responder(ABC):
"""Represents a response that can be streamed to the requester.
@@ -601,3 +672,151 @@ def _parseparam(s: bytes) -> Generator[bytes, None, None]:
f = s[:end]
yield f.strip()
s = s[end:]
+
+
+@implementer(interfaces.IPushProducer)
+class ThreadedFileSender:
+ """
+ A producer that sends the contents of a file to a consumer, reading from the
+ file on a thread.
+
+ This works by having a loop in a threadpool repeatedly reading from the
+ file, until the consumer pauses the producer. There is then a loop in the
+ main thread that waits until the consumer resumes the producer and then
+ starts reading in the threadpool again.
+
+ This is done to ensure that we're never waiting in the threadpool, as
+    otherwise it's easy to starve it of threads.
+ """
+
+ # How much data to read in one go.
+ CHUNK_SIZE = 2**14
+
+ # How long we wait for the consumer to be ready again before aborting the
+ # read.
+ TIMEOUT_SECONDS = 90.0
+
+ def __init__(self, hs: "HomeServer") -> None:
+ self.reactor = hs.get_reactor()
+ self.thread_pool = hs.get_media_sender_thread_pool()
+
+ self.file: Optional[BinaryIO] = None
+ self.deferred: "Deferred[None]" = Deferred()
+ self.consumer: Optional[interfaces.IConsumer] = None
+
+ # Signals if the thread should keep reading/sending data. Set means
+ # continue, clear means pause.
+ self.wakeup_event = DeferredEvent(self.reactor)
+
+ # Signals if the thread should terminate, e.g. because the consumer has
+ # gone away.
+ self.stop_writing = False
+
+ def beginFileTransfer(
+ self, file: BinaryIO, consumer: interfaces.IConsumer
+ ) -> "Deferred[None]":
+ """
+ Begin transferring a file
+ """
+ self.file = file
+ self.consumer = consumer
+
+ self.consumer.registerProducer(self, True)
+
+ # We set the wakeup signal as we should start producing immediately.
+ self.wakeup_event.set()
+ run_in_background(self.start_read_loop)
+
+ return make_deferred_yieldable(self.deferred)
+
+ def resumeProducing(self) -> None:
+ """interfaces.IPushProducer"""
+ self.wakeup_event.set()
+
+ def pauseProducing(self) -> None:
+ """interfaces.IPushProducer"""
+ self.wakeup_event.clear()
+
+ def stopProducing(self) -> None:
+ """interfaces.IPushProducer"""
+
+ # Unregister the consumer so we don't try and interact with it again.
+ if self.consumer:
+ self.consumer.unregisterProducer()
+
+ self.consumer = None
+
+ # Terminate the loop.
+ self.stop_writing = True
+ self.wakeup_event.set()
+
+ if not self.deferred.called:
+ self.deferred.errback(Exception("Consumer asked us to stop producing"))
+
+ async def start_read_loop(self) -> None:
+ """This is the loop that drives reading/writing"""
+ try:
+ while not self.stop_writing:
+ # Start the loop in the threadpool to read data.
+ more_data = await defer_to_threadpool(
+ self.reactor, self.thread_pool, self._on_thread_read_loop
+ )
+ if not more_data:
+ # Reached EOF, we can just return.
+ return
+
+ if not self.wakeup_event.is_set():
+ ret = await self.wakeup_event.wait(self.TIMEOUT_SECONDS)
+ if not ret:
+ raise Exception("Timed out waiting to resume")
+ except Exception:
+ self._error(Failure())
+ finally:
+ self._finish()
+
+ def _on_thread_read_loop(self) -> bool:
+ """This is the loop that happens on a thread.
+
+ Returns:
+ Whether there is more data to send.
+ """
+
+ while not self.stop_writing and self.wakeup_event.is_set():
+ # The file should always have been set before we get here.
+ assert self.file is not None
+
+ chunk = self.file.read(self.CHUNK_SIZE)
+ if not chunk:
+ return False
+
+ self.reactor.callFromThread(self._write, chunk)
+
+ return True
+
+ def _write(self, chunk: bytes) -> None:
+ """Called from the thread to write a chunk of data"""
+ if self.consumer:
+ self.consumer.write(chunk)
+
+ def _error(self, failure: Failure) -> None:
+ """Called when there was a fatal error"""
+ if self.consumer:
+ self.consumer.unregisterProducer()
+ self.consumer = None
+
+ if not self.deferred.called:
+ self.deferred.errback(failure)
+
+ def _finish(self) -> None:
+ """Called when we have finished writing (either on success or
+ failure)."""
+ if self.file:
+ self.file.close()
+ self.file = None
+
+ if self.consumer:
+ self.consumer.unregisterProducer()
+ self.consumer = None
+
+ if not self.deferred.called:
+ self.deferred.callback(None)
|