diff --git a/synapse/http/client.py b/synapse/http/client.py
index 3cef747a4d..13fcab3378 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,13 +15,11 @@
# limitations under the License.
import logging
+import urllib
from io import BytesIO
-from six import raise_from, text_type
-from six.moves import urllib
-
import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -33,6 +31,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IResolutionReceiver,
)
+from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
@@ -48,6 +47,7 @@ from synapse.http import (
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
logger = logging.getLogger(__name__)
@@ -71,7 +71,22 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False
-class IPBlacklistingResolver(object):
+_EPSILON = 0.00000001
+
+
+def _make_scheduler(reactor):
+ """Makes a schedular suitable for a Cooperator using the given reactor.
+
+ (This is effectively just a copy from `twisted.internet.task`)
+ """
+
+ def _scheduler(x):
+ return reactor.callLater(_EPSILON, x)
+
+ return _scheduler
+
+
+class IPBlacklistingResolver:
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
@@ -118,7 +133,7 @@ class IPBlacklistingResolver(object):
r.resolutionComplete()
@provider(IResolutionReceiver)
- class EndpointReceiver(object):
+ class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress):
pass
@@ -177,7 +192,7 @@ class BlacklistingAgentWrapper(Agent):
)
-class SimpleHttpClient(object):
+class SimpleHttpClient:
"""
A simple, no-frills HTTP client with methods that wrap up common ways of
using HTTP in Matrix
@@ -214,6 +229,10 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+ # We use this for our body producers to ensure that they use the correct
+ # reactor.
+ self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
+
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
@@ -225,7 +244,7 @@ class SimpleHttpClient(object):
)
@implementer(IReactorPluggableNameResolver)
- class Reactor(object):
+ class Reactor:
def __getattr__(_self, attr):
if attr == "nameResolver":
return nameResolver
@@ -266,8 +285,7 @@ class SimpleHttpClient(object):
ip_blacklist=self._ip_blacklist,
)
- @defer.inlineCallbacks
- def request(self, method, uri, data=None, headers=None):
+ async def request(self, method, uri, data=None, headers=None):
"""
Args:
method (str): HTTP method to use.
@@ -280,7 +298,7 @@ class SimpleHttpClient(object):
outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this)
- logger.info("Sending request %s %s", method, redact_uri(uri))
+ logger.debug("Sending request %s %s", method, redact_uri(uri))
with start_active_span(
"outgoing-client-request",
@@ -294,7 +312,9 @@ class SimpleHttpClient(object):
try:
body_producer = None
if data is not None:
- body_producer = QuieterFileBodyProducer(BytesIO(data))
+ body_producer = QuieterFileBodyProducer(
+ BytesIO(data), cooperator=self._cooperator,
+ )
request_deferred = treq.request(
method,
@@ -310,7 +330,7 @@ class SimpleHttpClient(object):
self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
- response = yield make_deferred_yieldable(request_deferred)
+ response = await make_deferred_yieldable(request_deferred)
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
@@ -333,8 +353,7 @@ class SimpleHttpClient(object):
set_tag("error_reason", e.args[0])
raise
- @defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ async def post_urlencoded_get_json(self, uri, args={}, headers=None):
"""
Args:
uri (str):
@@ -343,7 +362,7 @@ class SimpleHttpClient(object):
header name to a list of values for that header
Returns:
- Deferred[object]: parsed json
+ object: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
@@ -366,19 +385,20 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request(
+ response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json_decoder.decode(body.decode("utf-8"))
else:
- raise HttpResponseException(response.code, response.phrase, body)
+ raise HttpResponseException(
+ response.code, response.phrase.decode("ascii", errors="replace"), body
+ )
- @defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json, headers=None):
+ async def post_json_get_json(self, uri, post_json, headers=None):
"""
Args:
@@ -388,7 +408,7 @@ class SimpleHttpClient(object):
header name to a list of values for that header
Returns:
- Deferred[object]: parsed json
+ object: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
@@ -407,19 +427,20 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request(
+ response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json_decoder.decode(body.decode("utf-8"))
else:
- raise HttpResponseException(response.code, response.phrase, body)
+ raise HttpResponseException(
+ response.code, response.phrase.decode("ascii", errors="replace"), body
+ )
- @defer.inlineCallbacks
- def get_json(self, uri, args={}, headers=None):
+ async def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -431,7 +452,7 @@ class SimpleHttpClient(object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
- Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
HttpResponseException On a non-2xx HTTP response.
@@ -442,11 +463,10 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- body = yield self.get_raw(uri, args, headers=headers)
- return json.loads(body)
+ body = await self.get_raw(uri, args, headers=headers)
+ return json_decoder.decode(body.decode("utf-8"))
- @defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}, headers=None):
+ async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -459,7 +479,7 @@ class SimpleHttpClient(object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
- Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
HttpResponseException On a non-2xx HTTP response.
@@ -480,19 +500,20 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request(
+ response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- return json.loads(body)
+ return json_decoder.decode(body.decode("utf-8"))
else:
- raise HttpResponseException(response.code, response.phrase, body)
+ raise HttpResponseException(
+ response.code, response.phrase.decode("ascii", errors="replace"), body
+ )
- @defer.inlineCallbacks
- def get_raw(self, uri, args={}, headers=None):
+ async def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -504,8 +525,8 @@ class SimpleHttpClient(object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
- Deferred: Succeeds when we get *any* 2xx HTTP response, with the
- HTTP body at text.
+ Succeeds when we get *any* 2xx HTTP response, with the
+ HTTP body as bytes.
Raises:
HttpResponseException on a non-2xx HTTP response.
"""
@@ -517,20 +538,21 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request("GET", uri, headers=Headers(actual_headers))
+ response = await self.request("GET", uri, headers=Headers(actual_headers))
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return body
else:
- raise HttpResponseException(response.code, response.phrase, body)
+ raise HttpResponseException(
+ response.code, response.phrase.decode("ascii", errors="replace"), body
+ )
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
- @defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None, headers=None):
+ async def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
@@ -546,7 +568,7 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request("GET", url, headers=Headers(actual_headers))
+ response = await self.request("GET", url, headers=Headers(actual_headers))
resp_headers = dict(response.headers.getAllRawHeaders())
@@ -570,14 +592,14 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield make_deferred_yieldable(
+ length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
raise
except Exception as e:
- raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
+ raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
return (
length,
@@ -638,7 +660,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
- if isinstance(arg, text_type):
+ if isinstance(arg, str):
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
|