summary refs log tree commit diff
path: root/synapse/http/client.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/client.py42
1 files changed, 27 insertions, 15 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py

index 56ad28eabf..84a510fb42 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -31,18 +31,17 @@ from typing import ( List, Mapping, Optional, + Protocol, Tuple, Union, ) import attr -import multipart import treq from canonicaljson import encode_canonical_json from netaddr import AddrFormatError, IPAddress, IPSet from prometheus_client import Counter -from typing_extensions import Protocol -from zope.interface import implementer, provider +from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE @@ -93,6 +92,20 @@ from synapse.util.async_helpers import timeout_deferred if TYPE_CHECKING: from synapse.server import HomeServer +# Support both import names for the `python-multipart` (PyPI) library, +# which renamed its package name from `multipart` to `python_multipart` +# in 0.0.13 (though supports the old import name for compatibility). +# Note that the `multipart` package name conflicts with `multipart` (PyPI) +# so we should prefer importing from `python_multipart` when possible. +try: + from python_multipart import MultipartParser + + if TYPE_CHECKING: + from python_multipart import multipart +except ImportError: + from multipart import MultipartParser # type: ignore[no-redef] + + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) @@ -212,7 +225,7 @@ class _IPBlockingResolver: recv.addressResolved(address) recv.resolutionComplete() - @provider(IResolutionReceiver) + @implementer(IResolutionReceiver) class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress: IHostResolution) -> None: @@ -226,8 +239,9 @@ class _IPBlockingResolver: def resolutionComplete() -> None: _callback() + endpoint_receiver_wrapper = EndpointReceiver() self._reactor.nameResolver.resolveHostName( - EndpointReceiver, hostname, portNumber=portNumber + endpoint_receiver_wrapper, hostname, portNumber=portNumber ) return recv @@ -1039,7 +1053,7 @@ class _MultipartParserProtocol(protocol.Protocol): self.deferred = deferred self.boundary = boundary self.max_length = max_length - self.parser = None + self.parser: Optional[MultipartParser] = None self.multipart_response = MultipartResponse() self.has_redirect = False self.in_json = False @@ -1057,11 +1071,11 @@ class _MultipartParserProtocol(protocol.Protocol): if not self.parser: def on_header_field(data: bytes, start: int, end: int) -> None: - if data[start:end] == b"Location": + if data[start:end].lower() == b"location": self.has_redirect = True - if data[start:end] == b"Content-Disposition": + if data[start:end].lower() == b"content-disposition": self.in_disposition = True - if data[start:end] == b"Content-Type": + if data[start:end].lower() == b"content-type": self.in_content_type = True def on_header_value(data: bytes, start: int, end: int) -> None: @@ -1088,7 +1102,6 @@ class _MultipartParserProtocol(protocol.Protocol): return # otherwise we are in the file part else: - logger.info("Writing multipart file data to stream") try: self.stream.write(data[start:end]) except Exception as e: @@ -1098,12 +1111,12 @@ class _MultipartParserProtocol(protocol.Protocol): self.deferred.errback() self.file_length += end - start - callbacks = { + callbacks: "multipart.MultipartCallbacks" = { "on_header_field": on_header_field, "on_header_value": on_header_value, "on_part_data": on_part_data, } - self.parser = multipart.MultipartParser(self.boundary, callbacks) + self.parser = MultipartParser(self.boundary, callbacks) self.total_length += len(incoming_data) if self.max_length is not None and self.total_length >= self.max_length: @@ -1114,7 +1127,7 @@ class _MultipartParserProtocol(protocol.Protocol): self.transport.abortConnection() try: - self.parser.write(incoming_data) # type: ignore[attr-defined] + self.parser.write(incoming_data) except Exception as e: logger.warning(f"Exception writing to multipart parser: {e}") self.deferred.errback() @@ -1314,6 +1327,5 @@ def is_unknown_endpoint( ) ) or ( # Older Synapses returned a 400 error. - e.code == 400 - and synapse_error.errcode == Codes.UNRECOGNIZED + e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED )