summary refs log tree commit diff
path: root/synapse/http/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/client.py')
-rw-r--r--synapse/http/client.py29
1 files changed, 23 insertions, 6 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 743a7ffcb1..c01d2326cf 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -20,6 +20,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     BinaryIO,
+    Callable,
     Dict,
     Iterable,
     List,
@@ -321,21 +322,20 @@ class SimpleHttpClient:
         self._ip_whitelist = ip_whitelist
         self._ip_blacklist = ip_blacklist
         self._extra_treq_args = treq_args or {}
-
-        self.user_agent = hs.version_string
         self.clock = hs.get_clock()
+
+        user_agent = hs.version_string
         if hs.config.server.user_agent_suffix:
-            self.user_agent = "%s %s" % (
-                self.user_agent,
+            user_agent = "%s %s" % (
+                user_agent,
                 hs.config.server.user_agent_suffix,
             )
+        self.user_agent = user_agent.encode("ascii")
 
         # We use this for our body producers to ensure that they use the correct
         # reactor.
         self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
 
-        self.user_agent = self.user_agent.encode("ascii")
-
         if self._ip_blacklist:
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
@@ -693,12 +693,18 @@ class SimpleHttpClient:
         output_stream: BinaryIO,
         max_size: Optional[int] = None,
         headers: Optional[RawHeaders] = None,
+        is_allowed_content_type: Optional[Callable[[str], bool]] = None,
     ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
         """GETs a file from a given URL
         Args:
             url: The URL to GET
             output_stream: File to write the response body to.
             headers: A map from header name to a list of values for that header
+            is_allowed_content_type: A predicate to determine whether the
+                content type of the file we're downloading is allowed. If set and
+                it evaluates to False when called with the content type, the
+                request will be terminated before completing the download by
+                raising SynapseError.
         Returns:
             A tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
@@ -726,6 +732,17 @@ class SimpleHttpClient:
                 HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
             )
 
+        if is_allowed_content_type and b"Content-Type" in resp_headers:
+            content_type = resp_headers[b"Content-Type"][0].decode("ascii")
+            if not is_allowed_content_type(content_type):
+                raise SynapseError(
+                    HTTPStatus.BAD_GATEWAY,
+                    (
+                        "Requested file's content type not allowed for this operation: %s"
+                        % content_type
+                    ),
+                )
+
         # TODO: if our Content-Type is HTML or something, just read the first
         # N bytes into RAM rather than saving it all to disk only to read it
         # straight back in again