summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/__init__.py2
-rw-r--r--synapse/http/client.py18
-rw-r--r--synapse/http/federation/well_known_resolver.py13
-rw-r--r--synapse/http/matrixfederationclient.py40
-rw-r--r--synapse/http/proxyagent.py14
-rw-r--r--synapse/http/server.py8
-rw-r--r--synapse/http/servlet.py2
-rw-r--r--synapse/http/site.py16
8 files changed, 75 insertions, 38 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index ed4671b7de..578fc48ef4 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes:
         return hostname
 
     # no Host header, use the address/port that the request arrived on
-    host = request.getHost()  # type: Union[address.IPv4Address, address.IPv6Address]
+    host: Union[address.IPv4Address, address.IPv6Address] = request.getHost()
 
     hostname = host.host.encode("ascii")
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 1ca6624fd5..2ac76b15c2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -160,7 +160,7 @@ class _IPBlacklistingResolver:
     def resolveHostName(
         self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
     ) -> IResolutionReceiver:
-        addresses = []  # type: List[IAddress]
+        addresses: List[IAddress] = []
 
         def _callback() -> None:
             has_bad_ip = False
@@ -333,9 +333,9 @@ class SimpleHttpClient:
         if self._ip_blacklist:
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
-            self.reactor = BlacklistingReactorWrapper(
+            self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
                 hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
-            )  # type: ISynapseReactor
+            )
         else:
             self.reactor = hs.get_reactor()
 
@@ -349,14 +349,14 @@ class SimpleHttpClient:
         pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
         pool.cachedConnectionTimeout = 2 * 60
 
-        self.agent = ProxyAgent(
+        self.agent: IAgent = ProxyAgent(
             self.reactor,
             hs.get_reactor(),
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
             use_proxy=use_proxy,
-        )  # type: IAgent
+        )
 
         if self._ip_blacklist:
             # If we have an IP blacklist, we then install the blacklisting Agent
@@ -411,7 +411,7 @@ class SimpleHttpClient:
                         cooperator=self._cooperator,
                     )
 
-                request_deferred = treq.request(
+                request_deferred: defer.Deferred = treq.request(
                     method,
                     uri,
                     agent=self.agent,
@@ -421,7 +421,7 @@ class SimpleHttpClient:
                     # response bodies.
                     unbuffered=True,
                     **self._extra_treq_args,
-                )  # type: defer.Deferred
+                )
 
                 # we use our own timeout mechanism rather than treq's as a workaround
                 # for https://twistedmatrix.com/trac/ticket/9534.
@@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception):
 class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which immediately errors upon receiving data."""
 
-    transport = None  # type: Optional[ITCPTransport]
+    transport: Optional[ITCPTransport] = None
 
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
@@ -798,7 +798,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
 
-    transport = None  # type: Optional[ITCPTransport]
+    transport: Optional[ITCPTransport] = None
 
     def __init__(
         self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 20d39a4ea6..43f2140429 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -70,10 +70,8 @@ WELL_KNOWN_RETRY_ATTEMPTS = 3
 logger = logging.getLogger(__name__)
 
 
-_well_known_cache = TTLCache("well-known")  # type: TTLCache[bytes, Optional[bytes]]
-_had_valid_well_known_cache = TTLCache(
-    "had-valid-well-known"
-)  # type: TTLCache[bytes, bool]
+_well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
+_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)
@@ -130,9 +128,10 @@ class WellKnownResolver:
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = await self._fetch_well_known(
-                    server_name
-                )  # type: Optional[bytes], float
+                result: Optional[bytes]
+                cache_period: float
+
+                result, cache_period = await self._fetch_well_known(server_name)
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b8849c0150..2efa15bf04 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -43,6 +43,7 @@ from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorTime
 from twisted.internet.task import _EPSILON, Cooperator
+from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
 
@@ -105,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
     the parsed data.
     """
 
-    CONTENT_TYPE = abc.abstractproperty()  # type: str  # type: ignore
+    CONTENT_TYPE: str = abc.abstractproperty()  # type: ignore
     """The expected content type of the response, e.g. `application/json`. If
     the content type doesn't match we fail the request.
     """
@@ -262,6 +263,15 @@ async def _handle_response(
             request.uri.decode("ascii"),
         )
         raise RequestSendFailed(e, can_retry=True) from e
+    except ResponseFailed as e:
+        logger.warning(
+            "{%s} [%s] Failed to read response - %s %s",
+            request.txn_id,
+            request.destination,
+            request.method,
+            request.uri.decode("ascii"),
+        )
+        raise RequestSendFailed(e, can_retry=True) from e
     except Exception as e:
         logger.warning(
             "{%s} [%s] Error reading response %s %s: %s",
@@ -317,11 +327,11 @@ class MatrixFederationHttpClient:
 
         # We need to use a DNS resolver which filters out blacklisted IP
         # addresses, to prevent DNS rebinding.
-        self.reactor = BlacklistingReactorWrapper(
+        self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
             hs.get_reactor(),
             hs.config.federation_ip_range_whitelist,
             hs.config.federation_ip_range_blacklist,
-        )  # type: ISynapseReactor
+        )
 
         user_agent = hs.version_string
         if hs.config.user_agent_suffix:
@@ -494,7 +504,7 @@ class MatrixFederationHttpClient:
         )
 
         # Inject the span into the headers
-        headers_dict = {}  # type: Dict[bytes, List[bytes]]
+        headers_dict: Dict[bytes, List[bytes]] = {}
         opentracing.inject_header_dict(headers_dict, request.destination)
 
         headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -523,9 +533,9 @@ class MatrixFederationHttpClient:
                             destination_bytes, method_bytes, url_to_sign_bytes, json
                         )
                         data = encode_canonical_json(json)
-                        producer = QuieterFileBodyProducer(
+                        producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
                             BytesIO(data), cooperator=self._cooperator
-                        )  # type: Optional[IBodyProducer]
+                        )
                     else:
                         producer = None
                         auth_headers = self.build_auth_headers(
@@ -1137,6 +1147,24 @@ class MatrixFederationHttpClient:
                 msg,
             )
             raise SynapseError(502, msg, Codes.TOO_LARGE)
+        except defer.TimeoutError as e:
+            logger.warning(
+                "{%s} [%s] Timed out reading response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
+        except ResponseFailed as e:
+            logger.warning(
+                "{%s} [%s] Failed to read response - %s %s",
+                request.txn_id,
+                request.destination,
+                request.method,
+                request.uri.decode("ascii"),
+            )
+            raise RequestSendFailed(e, can_retry=True) from e
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index 7dfae8b786..f7193e60bd 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -117,7 +117,8 @@ class ProxyAgent(_AgentBase):
             https_proxy = proxies["https"].encode() if "https" in proxies else None
             no_proxy = proxies["no"] if "no" in proxies else None
 
-        # Parse credentials from https proxy connection string if present
+        # Parse credentials from http and https proxy connection string if present
+        self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
         self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
 
         self.http_proxy_endpoint = _http_proxy_endpoint(
@@ -171,7 +172,7 @@ class ProxyAgent(_AgentBase):
         """
         uri = uri.strip()
         if not _VALID_URI.match(uri):
-            raise ValueError("Invalid URI {!r}".format(uri))
+            raise ValueError(f"Invalid URI {uri!r}")
 
         parsed_uri = URI.fromBytes(uri)
         pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
@@ -189,6 +190,15 @@ class ProxyAgent(_AgentBase):
             and self.http_proxy_endpoint
             and not should_skip_proxy
         ):
+            # Determine whether we need to set Proxy-Authorization headers
+            if self.http_proxy_creds:
+                # Set a Proxy-Authorization header
+                if headers is None:
+                    headers = Headers()
+                headers.addRawHeader(
+                    b"Proxy-Authorization",
+                    self.http_proxy_creds.as_proxy_authorization_value(),
+                )
             # Cache *all* connections under the same key, since we are only
             # connecting to a single destination, the proxy:
             pool_key = ("http-proxy", self.http_proxy_endpoint)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index efbc6d5b25..b79fa722e9 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
 
     if f.check(SynapseError):
         # mypy doesn't understand that f.check asserts the type.
-        exc = f.value  # type: SynapseError  # type: ignore
+        exc: SynapseError = f.value  # type: ignore
         error_code = exc.code
         error_dict = exc.error_dict()
 
@@ -132,7 +132,7 @@ def return_html_error(
     """
     if f.check(CodeMessageException):
         # mypy doesn't understand that f.check asserts the type.
-        cme = f.value  # type: CodeMessageException  # type: ignore
+        cme: CodeMessageException = f.value  # type: ignore
         code = cme.code
         msg = cme.msg
 
@@ -404,7 +404,7 @@ class JsonResource(DirectServeJsonResource):
             key word arguments to pass to the callback
         """
         # At this point the path must be bytes.
-        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path_bytes: bytes = request.path  # type: ignore
         request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
         request_method = request.method
@@ -557,7 +557,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request  # type: Optional[Request]
+        self._request: Optional[Request] = request
         self._iterator = iterator
         self._paused = False
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 6ba2ce1e53..04560fb589 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -205,7 +205,7 @@ def parse_string(
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
-    args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
+    args: Dict[bytes, List[bytes]] = request.args  # type: ignore
     return parse_string_from_args(
         args,
         name,
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 40754b7bea..190084e8aa 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -64,16 +64,16 @@ class SynapseRequest(Request):
     def __init__(self, channel, *args, max_request_body_size=1024, **kw):
         Request.__init__(self, channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
-        self.site = channel.site  # type: SynapseSite
+        self.site: SynapseSite = channel.site
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
 
         # The requester, if authenticated. For federation requests this is the
         # server name, for client requests this is the Requester object.
-        self._requester = None  # type: Optional[Union[Requester, str]]
+        self._requester: Optional[Union[Requester, str]] = None
 
         # we can't yet create the logcontext, as we don't know the method.
-        self.logcontext = None  # type: Optional[LoggingContext]
+        self.logcontext: Optional[LoggingContext] = None
 
         global _next_request_seq
         self.request_seq = _next_request_seq
@@ -152,7 +152,7 @@ class SynapseRequest(Request):
         Returns:
             The redacted URI as a string.
         """
-        uri = self.uri  # type: Union[bytes, str]
+        uri: Union[bytes, str] = self.uri
         if isinstance(uri, bytes):
             uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)
@@ -167,7 +167,7 @@ class SynapseRequest(Request):
         Returns:
             The request method as a string.
         """
-        method = self.method  # type: Union[bytes, str]
+        method: Union[bytes, str] = self.method
         if isinstance(method, bytes):
             return self.method.decode("ascii")
         return method
@@ -384,7 +384,7 @@ class SynapseRequest(Request):
         # authenticated (e.g. and admin is puppetting a user) then we log both.
         requester, authenticated_entity = self.get_authenticated_entity()
         if authenticated_entity:
-            requester = "{}.{}".format(authenticated_entity, requester)
+            requester = f"{authenticated_entity}.{requester}"
 
         self.site.access_logger.log(
             log_level,
@@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest):
     """
 
     # the client IP and ssl flag, as extracted from the headers.
-    _forwarded_for = None  # type: Optional[_XForwardedForAddress]
-    _forwarded_https = False  # type: bool
+    _forwarded_for: "Optional[_XForwardedForAddress]" = None
+    _forwarded_https: bool = False
 
     def requestReceived(self, command, path, version):
         # this method is called by the Channel once the full request has been