diff --git a/synapse/http/client.py b/synapse/http/client.py
index 56ad28eabf..84a510fb42 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -31,18 +31,17 @@ from typing import (
List,
Mapping,
Optional,
+ Protocol,
Tuple,
Union,
)
import attr
-import multipart
import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
-from typing_extensions import Protocol
-from zope.interface import implementer, provider
+from zope.interface import implementer
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
@@ -93,6 +92,20 @@ from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.server import HomeServer
+# Support both import names for the `python-multipart` (PyPI) library,
+# which renamed its package name from `multipart` to `python_multipart`
+# in 0.0.13 (though supports the old import name for compatibility).
+# Note that the `multipart` package name conflicts with `multipart` (PyPI)
+# so we should prefer importing from `python_multipart` when possible.
+try:
+ from python_multipart import MultipartParser
+
+ if TYPE_CHECKING:
+ from python_multipart import multipart
+except ImportError:
+ from multipart import MultipartParser # type: ignore[no-redef]
+
+
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -212,7 +225,7 @@ class _IPBlockingResolver:
recv.addressResolved(address)
recv.resolutionComplete()
- @provider(IResolutionReceiver)
+ @implementer(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
@@ -226,8 +239,9 @@ class _IPBlockingResolver:
def resolutionComplete() -> None:
_callback()
+ endpoint_receiver_wrapper = EndpointReceiver()
self._reactor.nameResolver.resolveHostName(
- EndpointReceiver, hostname, portNumber=portNumber
+ endpoint_receiver_wrapper, hostname, portNumber=portNumber
)
return recv
@@ -1039,7 +1053,7 @@ class _MultipartParserProtocol(protocol.Protocol):
self.deferred = deferred
self.boundary = boundary
self.max_length = max_length
- self.parser = None
+ self.parser: Optional[MultipartParser] = None
self.multipart_response = MultipartResponse()
self.has_redirect = False
self.in_json = False
@@ -1057,11 +1071,11 @@ class _MultipartParserProtocol(protocol.Protocol):
if not self.parser:
def on_header_field(data: bytes, start: int, end: int) -> None:
- if data[start:end] == b"Location":
+ if data[start:end].lower() == b"location":
self.has_redirect = True
- if data[start:end] == b"Content-Disposition":
+ if data[start:end].lower() == b"content-disposition":
self.in_disposition = True
- if data[start:end] == b"Content-Type":
+ if data[start:end].lower() == b"content-type":
self.in_content_type = True
def on_header_value(data: bytes, start: int, end: int) -> None:
@@ -1088,7 +1102,6 @@ class _MultipartParserProtocol(protocol.Protocol):
return
# otherwise we are in the file part
else:
- logger.info("Writing multipart file data to stream")
try:
self.stream.write(data[start:end])
except Exception as e:
@@ -1098,12 +1111,12 @@ class _MultipartParserProtocol(protocol.Protocol):
self.deferred.errback()
self.file_length += end - start
- callbacks = {
+ callbacks: "multipart.MultipartCallbacks" = {
"on_header_field": on_header_field,
"on_header_value": on_header_value,
"on_part_data": on_part_data,
}
- self.parser = multipart.MultipartParser(self.boundary, callbacks)
+ self.parser = MultipartParser(self.boundary, callbacks)
self.total_length += len(incoming_data)
if self.max_length is not None and self.total_length >= self.max_length:
@@ -1114,7 +1127,7 @@ class _MultipartParserProtocol(protocol.Protocol):
self.transport.abortConnection()
try:
- self.parser.write(incoming_data) # type: ignore[attr-defined]
+ self.parser.write(incoming_data)
except Exception as e:
logger.warning(f"Exception writing to multipart parser: {e}")
self.deferred.errback()
@@ -1314,6 +1327,5 @@ def is_unknown_endpoint(
)
) or (
# Older Synapses returned a 400 error.
- e.code == 400
- and synapse_error.errcode == Codes.UNRECOGNIZED
+ e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
)
|