summary refs log tree commit diff
path: root/synapse/http/matrixfederationclient.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/matrixfederationclient.py69
1 files changed, 36 insertions, 33 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 9145405cb0..821aed362b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,17 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import synapse.util.retryutils
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, HTTPConnectionPool, Agent
 from twisted.web.http_headers import Headers
 from twisted.web._newclient import ResponseDone
 
+from synapse.http import cancelled_to_request_timed_out_error
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util import logcontext
 import synapse.metrics
+from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util import logcontext
+from synapse.util.logcontext import make_deferred_yieldable
+import synapse.util.retryutils
 
 from canonicaljson import encode_canonical_json
 
@@ -38,22 +41,19 @@ import logging
 import random
 import sys
 import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
+from six import string_types
+
 
+from prometheus_client import Counter
 
 logger = logging.getLogger(__name__)
 outbound_logger = logging.getLogger("synapse.http.outbound")
 
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
-    "requests",
-    labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
-    "responses",
-    labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
+                                    "", ["method"])
+incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
+                                     "", ["method", "code"])
 
 
 MAX_LONG_RETRIES = 10
@@ -184,21 +184,20 @@ class MatrixFederationHttpClient(object):
                         producer = body_callback(method, http_url_bytes, headers_dict)
 
                     try:
-                        def send_request():
-                            request_deferred = self.agent.request(
-                                method,
-                                url_bytes,
-                                Headers(headers_dict),
-                                producer
-                            )
-
-                            return self.clock.time_bound_deferred(
-                                request_deferred,
-                                time_out=timeout / 1000. if timeout else 60,
-                            )
-
-                        with logcontext.PreserveLoggingContext():
-                            response = yield send_request()
+                        request_deferred = self.agent.request(
+                            method,
+                            url_bytes,
+                            Headers(headers_dict),
+                            producer
+                        )
+                        add_timeout_to_deferred(
+                            request_deferred,
+                            timeout / 1000. if timeout else 60,
+                            cancelled_to_request_timed_out_error,
+                        )
+                        response = yield make_deferred_yieldable(
+                            request_deferred,
+                        )
 
                         log_result = "%d %s" % (response.code, response.phrase,)
                         break
@@ -286,7 +285,8 @@ class MatrixFederationHttpClient(object):
         headers_dict[b"Authorization"] = auth_headers
 
     @defer.inlineCallbacks
-    def put_json(self, destination, path, data={}, json_data_callback=None,
+    def put_json(self, destination, path, args={}, data={},
+                 json_data_callback=None,
                  long_retries=False, timeout=None,
                  ignore_backoff=False,
                  backoff_on_404=False):
@@ -296,6 +296,7 @@ class MatrixFederationHttpClient(object):
             destination (str): The remote server to send the HTTP request
                 to.
             path (str): The HTTP path.
+            args (dict): query params
             data (dict): A dict containing the data that will be used as
                 the request body. This will be encoded as JSON.
             json_data_callback (callable): A callable returning the dict to
@@ -342,6 +343,7 @@ class MatrixFederationHttpClient(object):
             path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
+            query_bytes=encode_query_args(args),
             long_retries=long_retries,
             timeout=timeout,
             ignore_backoff=ignore_backoff,
@@ -373,6 +375,7 @@ class MatrixFederationHttpClient(object):
                 giving up. None indicates no timeout.
             ignore_backoff (bool): true to ignore the historical backoff data and
                 try the request anyway.
+            args (dict): query params
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
             will be the decoded JSON body.
@@ -548,7 +551,7 @@ class MatrixFederationHttpClient(object):
 
         encoded_args = {}
         for k, vs in args.items():
-            if isinstance(vs, basestring):
+            if isinstance(vs, string_types):
                 vs = [vs]
             encoded_args[k] = [v.encode("UTF-8") for v in vs]
 
@@ -663,7 +666,7 @@ def check_content_type_is_json(headers):
         RuntimeError if the
 
     """
-    c_type = headers.getRawHeaders("Content-Type")
+    c_type = headers.getRawHeaders(b"Content-Type")
     if c_type is None:
         raise RuntimeError(
             "No Content-Type header"
@@ -680,7 +683,7 @@ def check_content_type_is_json(headers):
 def encode_query_args(args):
     encoded_args = {}
     for k, vs in args.items():
-        if isinstance(vs, basestring):
+        if isinstance(vs, string_types):
             vs = [vs]
         encoded_args[k] = [v.encode("UTF-8") for v in vs]