diff options
Diffstat (limited to 'synapse/rest/media/v1/_base.py')
-rw-r--r-- | synapse/rest/media/v1/_base.py | 96 |
1 files changed, 74 insertions, 22 deletions
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index efe42a429d..fece1ef0b8 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -133,8 +134,15 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam logger.debug("Responding to media request with responder %s") add_file_headers(request, media_type, file_size, upload_name) - with responder: - yield responder.write_to_consumer(request) + try: + with responder: + yield responder.write_to_consumer(request) + except Exception as e: + # The majority of the time this will be due to the client having gone + # away. Unfortunately, Twisted simply throws a generic exception at us + # in that case. + logger.warning("Failed to write to consumer: %s %s", type(e), e) + finish_request(request) @@ -206,8 +214,7 @@ def get_filename_from_headers(headers): Content-Disposition HTTP header. Args: - headers (twisted.web.http_headers.Headers): The HTTP - request headers. + headers (dict[bytes, list[bytes]]): The HTTP request headers. Returns: A Unicode string of the filename, or None. @@ -218,23 +225,12 @@ def get_filename_from_headers(headers): if not content_disposition[0]: return - # dict of unicode: bytes, corresponding to the key value sections of the - # Content-Disposition header. - params = {} - parts = content_disposition[0].split(b";") - for i in parts: - # Split into key-value pairs, if able - # We don't care about things like `inline`, so throw it out - if b"=" not in i: - continue - - key, value = i.strip().split(b"=") - params[key.decode('ascii')] = value + _, params = _parse_header(content_disposition[0]) upload_name = None # First check if there is a valid UTF-8 filename - upload_name_utf8 = params.get("filename*", None) + upload_name_utf8 = params.get(b"filename*", None) if upload_name_utf8: if upload_name_utf8.lower().startswith(b"utf-8''"): upload_name_utf8 = upload_name_utf8[7:] @@ -260,12 +256,68 @@ def get_filename_from_headers(headers): # If there isn't check for an ascii name. if not upload_name: - upload_name_ascii = params.get("filename", None) + upload_name_ascii = params.get(b"filename", None) if upload_name_ascii and is_ascii(upload_name_ascii): - # Make sure there's no %-quoted bytes. If there is, reject it as - # non-valid ASCII. - if b"%" not in upload_name_ascii: - upload_name = upload_name_ascii.decode('ascii') + upload_name = upload_name_ascii.decode('ascii') # This may be None here, indicating we did not find a matching name. return upload_name + + +def _parse_header(line): + """Parse a Content-type like header. + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + line (bytes): header to be parsed + + Returns: + Tuple[bytes, dict[bytes, bytes]]: + the main content-type, followed by the parameter dictionary + """ + parts = _parseparam(b';' + line) + key = next(parts) + pdict = {} + for p in parts: + i = p.find(b'=') + if i >= 0: + name = p[:i].strip().lower() + value = p[i + 1:].strip() + + # strip double-quotes + if len(value) >= 2 and value[0:1] == value[-1:] == b'"': + value = value[1:-1] + value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') + pdict[name] = value + + return key, pdict + + +def _parseparam(s): + """Generator which splits the input on ;, respecting double-quoted sequences + + Cargo-culted from `cgi`, but works on bytes rather than strings. + + Args: + s (bytes): header to be parsed + + Returns: + Iterable[bytes]: the split input + """ + while s[:1] == b';': + s = s[1:] + + # look for the next ; + end = s.find(b';') + + # if there is an odd number of " marks between here and the next ;, skip to the + # next ; instead + while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2: + end = s.find(b';', end + 1) + + if end < 0: + end = len(s) + f = s[:end] + yield f.strip() + s = s[end:] |