diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2f770f5d..70a19d9b74 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,9 +17,12 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
- CodeMessageException, SynapseError, Codes,
+ CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.logcontext import make_deferred_yieldable
import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
@@ -29,13 +33,14 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError,
+ HTTPConnectionPool,
)
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
-from StringIO import StringIO
+from six import StringIO
import simplejson as json
import logging
@@ -63,92 +68,139 @@ class SimpleHttpClient(object):
"""
def __init__(self, hs):
self.hs = hs
+
+ pool = HTTPConnectionPool(reactor)
+
+ # the pusher makes lots of concurrent SSL connections to sygnal, and
+ # tends to do so in batches, so we need to allow the pool to keep lots
+ # of idle connections around.
+ pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+ pool.cachedConnectionTimeout = 2 * 60
+
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
connectTimeout=15,
- contextFactory=hs.get_http_client_context_factory()
+ contextFactory=hs.get_http_client_context_factory(),
+ pool=pool,
)
self.user_agent = hs.version_string
+ self.clock = hs.get_clock()
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
+ @defer.inlineCallbacks
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.inc(method)
- d = preserve_context_over_fn(
- self.agent.request,
- method, uri, *args, **kwargs
- )
logger.info("Sending request %s %s", method, uri)
- def _cb(response):
+ try:
+ request_deferred = self.agent.request(
+ method, uri, *args, **kwargs
+ )
+ add_timeout_to_deferred(
+ request_deferred,
+ 60, cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(request_deferred)
+
incoming_responses_counter.inc(method, response.code)
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
- return response
-
- def _eb(failure):
+ defer.returnValue(response)
+ except Exception as e:
incoming_responses_counter.inc(method, "ERR")
logger.info(
"Error sending request to %s %s: %s %s",
- method, uri, failure.type, failure.getErrorMessage()
+ method, uri, type(e).__name__, e.message
)
- return failure
+ raise e
- d.addCallbacks(_cb, _eb)
+ @defer.inlineCallbacks
+ def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ """
+ Args:
+ uri (str):
+            args (dict[str, str|List[str]]): form fields, urlencoded into the POST body
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
- return d
+ Returns:
+ Deferred[object]: parsed json
+ """
- @defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
+ actual_headers = {
+ b"Content-Type": [b"application/x-www-form-urlencoded"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json):
+ def post_json_get_json(self, uri, post_json, headers=None):
+ """
+
+ Args:
+ uri (str):
+ post_json (object):
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
+
+ Returns:
+ Deferred[object]: parsed json
+ """
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"POST",
uri.encode("ascii"),
- headers=Headers({
- b"Content-Type": [b"application/json"],
- b"User-Agent": [self.user_agent],
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
+
+ if 200 <= response.code < 300:
+ defer.returnValue(json.loads(body))
+ else:
+ raise self._exceptionFromFailedRequest(response, body)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
- def get_json(self, uri, args={}):
+ def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -157,6 +209,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -164,11 +218,14 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
- body = yield self.get_raw(uri, args)
- defer.returnValue(json.loads(body))
+ try:
+ body = yield self.get_raw(uri, args, headers=headers)
+ defer.returnValue(json.loads(body))
+ except CodeMessageException as e:
+ raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}):
+ def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -178,6 +235,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
@@ -190,17 +249,21 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
+ actual_headers = {
+ b"Content-Type": [b"application/json"],
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"PUT",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- "Content-Type": ["application/json"]
- }),
+ headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -211,7 +274,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
- def get_raw(self, uri, args={}):
+ def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -220,6 +283,8 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
@@ -231,46 +296,65 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
uri.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(body)
else:
raise CodeMessageException(response.code, body)
+    def _exceptionFromFailedRequest(self, response, body):
+        """Build the exception to raise for a failed (non-2xx) HTTP request.
+
+        Args:
+            response (object|int): either an object exposing the HTTP status
+                as a `code` attribute, or the bare integer status code itself
+                (callers that only hold a caught CodeMessageException pass
+                its `code`).
+            body (str): the raw response body.
+
+        Returns:
+            MatrixCodeMessageException|CodeMessageException: the former when
+            the body is a JSON object carrying both 'errcode' and 'error';
+            otherwise the latter, wrapping the raw body.
+        """
+        # Accept both call styles: fall back to the value itself when it has
+        # no `code` attribute (i.e. it already is the status code).
+        code = getattr(response, "code", response)
+        try:
+            jsonBody = json.loads(body)
+            errcode = jsonBody['errcode']
+            error = jsonBody['error']
+            return MatrixCodeMessageException(code, error, errcode)
+        except (ValueError, KeyError, TypeError):
+            # ValueError: body is not valid JSON; KeyError: a field is
+            # missing; TypeError: the JSON value is not an object (for
+            # example a list or null), so it cannot be indexed by name.
+            return CodeMessageException(code, body)
+
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
@defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None):
+ def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
+ headers (dict[str, List[str]]|None): If not None, a map from
+ header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
+ actual_headers = {
+ b"User-Agent": [self.user_agent],
+ }
+ if headers:
+ actual_headers.update(headers)
+
response = yield self.request(
"GET",
url.encode("ascii"),
- headers=Headers({
- b"User-Agent": [self.user_agent],
- })
+ headers=Headers(actual_headers),
)
- headers = dict(response.headers.getAllRawHeaders())
+ resp_headers = dict(response.headers.getAllRawHeaders())
- if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+ if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
@@ -291,10 +375,9 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield preserve_context_over_fn(
- _readBodyToFile,
- response, output_stream, max_size
- )
+ length = yield make_deferred_yieldable(_readBodyToFile(
+ response, output_stream, max_size,
+ ))
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
@@ -303,7 +386,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN,
)
- defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+ defer.returnValue(
+ (length, resp_headers, response.request.absoluteURI, response.code),
+ )
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -371,7 +456,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield preserve_context_over_fn(readBody, response)
+ body = yield make_deferred_yieldable(readBody(response))
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
@@ -422,7 +507,7 @@ class SpiderHttpClient(SimpleHttpClient):
reactor,
SpiderEndpointFactory(hs)
)
- ), [('gzip', GzipDecoder)]
+ ), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
|