diff --git a/synapse/http/client.py b/synapse/http/client.py
index f409368802..e5b13593f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
+import urllib.parse
from io import BytesIO
from typing import (
+ TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress
+from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
+ IAddress,
+ IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+def check_against_blacklist(
+ ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
"""
+ Compares an IP address to allowed and disallowed IP sets.
+
Args:
- ip_address (netaddr.IPAddress)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ ip_address: The IP address to check
+ ip_whitelist: Allowed IP addresses.
+ ip_blacklist: Disallowed IP addresses.
+
+ Returns:
+ True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
- def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
"""
Args:
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ reactor: The twisted reactor.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def resolveHostName(self, recv, hostname, portNumber=0):
+ def resolveHostName(
+ self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+ ) -> IResolutionReceiver:
r = recv()
- addresses = []
+ addresses = [] # type: List[IAddress]
- def _callback():
+ def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
- def resolutionBegan(resolutionInProgress):
+ def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
- def addressResolved(address):
+ def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
- def resolutionComplete():
+ def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
- def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ agent: IAgent,
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ ):
"""
Args:
- agent (twisted.web.client.Agent): The Agent to wrap.
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ agent: The Agent to wrap.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
def __init__(
self,
- hs,
- treq_args={},
- ip_whitelist=None,
- ip_blacklist=None,
- http_proxy=None,
- https_proxy=None,
+ hs: "HomeServer",
+ treq_args: Dict[str, Any] = {},
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ http_proxy: Optional[bytes] = None,
+ https_proxy: Optional[bytes] = None,
):
"""
Args:
- hs (synapse.server.HomeServer)
- treq_args (dict): Extra keyword arguments to be given to treq.request.
- ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ hs: The HomeServer instance.
+ treq_args: Extra keyword arguments to be given to treq.request.
+ ip_blacklist: The IP addresses that are blacklisted that
we may not request.
- ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy (bytes): proxy server to use for http connections. host[:port]
- https_proxy (bytes): proxy server to use for https connections. host[:port]
+ http_proxy: proxy server to use for http connections. host[:port]
+ https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -306,7 +336,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
- self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@@ -397,7 +426,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
- args: Mapping[str, Union[str, List[str]]] = {},
+ args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,9 +451,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
- "utf8"
- )
+ query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ class SimpleHttpClient:
)
async def get_json(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
- args: QueryParams = {},
+ args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@@ -558,7 +588,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ class SimpleHttpClient:
)
async def get_raw(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@@ -641,7 +674,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@@ -649,12 +682,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
+ and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
- "Requested file is too large > %r bytes" % (self.max_size,),
+ "Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@@ -668,7 +702,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- _readBodyToFile(response, output_stream, max_size)
+ readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
-
-
class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
+ def __init__(
+ self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+ ):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
+def readBodyToFile(
+ response: IResponse, stream: BinaryIO, max_size: Optional[int]
+) -> defer.Deferred:
+ """
+ Read an HTTP response body to a file-object, optionally enforcing a maximum file size.
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ max_size: The maximum file size to allow, or None for no limit.
+
+ Returns:
+ A Deferred which resolves to the length of the read body.
+ """
-def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
-def encode_urlencode_args(args):
- return {k: encode_urlencode_arg(v) for k, v in args.items()}
+def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+ """
+ Encodes a map of query arguments to bytes which can be appended to a URL.
+ Args:
+ args: The query arguments, a mapping of string to string or list of strings.
+
+ Returns:
+ The query arguments encoded as bytes.
+ """
+ if args is None:
+ return b""
-def encode_urlencode_arg(arg):
- if isinstance(arg, str):
- return arg.encode("utf-8")
- elif isinstance(arg, list):
- return [encode_urlencode_arg(i) for i in arg]
- else:
- return arg
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, str):
+ vs = [vs]
+ encoded_args[k] = [v.encode("utf8") for v in vs]
+ query_str = urllib.parse.urlencode(encoded_args, True)
-def _print_ex(e):
- if hasattr(e, "reasons") and e.reasons:
- for ex in e.reasons:
- _print_ex(ex)
- else:
- logger.exception(e)
+ return query_str.encode("utf8")
class InsecureInterceptableContextFactory(ssl.ContextFactory):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b2ccae90df..4e27f93b7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,7 @@ import random
import sys
import urllib.parse
from io import BytesIO
-from typing import BinaryIO, Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@@ -28,26 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
-from twisted.internet import defer, protocol
+from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
- Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
- SynapseError,
)
from synapse.http import QuieterFileBodyProducer
-from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ IPBlacklistingResolver,
+ encode_query_args,
+ readBodyToFile,
+)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
- self.reactor,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -986,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = _readBodyToFile(response, output_stream, max_size)
+ d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1010,44 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
-class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(
- self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
- ):
- self.stream = stream
- self.deferred = deferred
- self.length = 0
- self.max_size = max_size
-
- def dataReceived(self, data: bytes) -> None:
- self.stream.write(data)
- self.length += len(data)
- if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
- self.transport.loseConnection()
-
- def connectionLost(self, reason: Failure) -> None:
- if reason.check(ResponseDone):
- self.deferred.callback(self.length)
- else:
- self.deferred.errback(reason)
-
-
-def _readBodyToFile(
- response: IResponse, stream: BinaryIO, max_size: Optional[int]
-) -> defer.Deferred:
- d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
- return d
-
-
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1088,18 +1049,3 @@ def check_content_type_is_json(headers: Headers) -> None:
),
can_retry=False,
)
-
-
-def encode_query_args(args: Optional[QueryArgs]) -> bytes:
- if args is None:
- return b""
-
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, str):
- vs = [vs]
- encoded_args[k] = [v.encode("utf8") for v in vs]
-
- query_str = urllib.parse.urlencode(encoded_args, True)
-
- return query_str.encode("utf8")
|