summary refs log tree commit diff
path: root/synapse/rest/media
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/media')
-rw-r--r--synapse/rest/media/v1/_base.py85
1 files changed, 65 insertions, 20 deletions
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index d16a30acd8..fece1ef0b8 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -213,8 +214,7 @@ def get_filename_from_headers(headers):
     Content-Disposition HTTP header.
 
     Args:
-        headers (twisted.web.http_headers.Headers): The HTTP
-            request headers.
+        headers (dict[bytes, list[bytes]]): The HTTP request headers.
 
     Returns:
         A Unicode string of the filename, or None.
@@ -225,23 +225,12 @@ def get_filename_from_headers(headers):
     if not content_disposition[0]:
         return
 
-    # dict of unicode: bytes, corresponding to the key value sections of the
-    # Content-Disposition header.
-    params = {}
-    parts = content_disposition[0].split(b";")
-    for i in parts:
-        # Split into key-value pairs, if able
-        # We don't care about things like `inline`, so throw it out
-        if b"=" not in i:
-            continue
-
-        key, value = i.strip().split(b"=")
-        params[key.decode('ascii')] = value
+    _, params = _parse_header(content_disposition[0])
 
     upload_name = None
 
     # First check if there is a valid UTF-8 filename
-    upload_name_utf8 = params.get("filename*", None)
+    upload_name_utf8 = params.get(b"filename*", None)
     if upload_name_utf8:
         if upload_name_utf8.lower().startswith(b"utf-8''"):
             upload_name_utf8 = upload_name_utf8[7:]
@@ -267,12 +256,68 @@ def get_filename_from_headers(headers):
 
     # If there isn't check for an ascii name.
     if not upload_name:
-        upload_name_ascii = params.get("filename", None)
+        upload_name_ascii = params.get(b"filename", None)
         if upload_name_ascii and is_ascii(upload_name_ascii):
-            # Make sure there's no %-quoted bytes. If there is, reject it as
-            # non-valid ASCII.
-            if b"%" not in upload_name_ascii:
-                upload_name = upload_name_ascii.decode('ascii')
+            upload_name = upload_name_ascii.decode('ascii')
 
     # This may be None here, indicating we did not find a matching name.
     return upload_name
+
+
+def _parse_header(line):
+    """Parse a Content-type like header.
+
+    Cargo-culted from `cgi`, but works on bytes rather than strings.
+
+    Args:
+        line (bytes): header to be parsed
+
+    Returns:
+        Tuple[bytes, dict[bytes, bytes]]:
+            the main content-type, followed by the parameter dictionary
+    """
+    parts = _parseparam(b';' + line)
+    key = next(parts)
+    pdict = {}
+    for p in parts:
+        i = p.find(b'=')
+        if i >= 0:
+            name = p[:i].strip().lower()
+            value = p[i + 1:].strip()
+
+            # strip double-quotes
+            if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
+                value = value[1:-1]
+                value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"')
+            pdict[name] = value
+
+    return key, pdict
+
+
+def _parseparam(s):
+    """Generator which splits the input on ;, respecting double-quoted sequences
+
+    Cargo-culted from `cgi`, but works on bytes rather than strings.
+
+    Args:
+        s (bytes): header to be parsed
+
+    Returns:
+        Iterable[bytes]: the split input
+    """
+    while s[:1] == b';':
+        s = s[1:]
+
+        # look for the next ;
+        end = s.find(b';')
+
+        # if there is an odd number of " marks between here and the next ;, skip to the
+        # next ; instead
+        while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
+            end = s.find(b';', end + 1)
+
+        if end < 0:
+            end = len(s)
+        f = s[:end]
+        yield f.strip()
+        s = s[end:]