summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/client.py10
-rw-r--r--synapse/http/matrixfederationclient.py58
-rw-r--r--synapse/util/logcontext.py52
3 files changed, 68 insertions, 52 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9091ae2d38..49737d55da 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -109,7 +109,7 @@ class SimpleHttpClient(object):
             bodyProducer=FileBodyProducer(StringIO(query_bytes))
         )
 
-        body = yield readBody(response)
+        body = yield preserve_context_over_fn(readBody, response)
 
         defer.returnValue(json.loads(body))
 
@@ -128,7 +128,7 @@ class SimpleHttpClient(object):
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield readBody(response)
+        body = yield preserve_context_over_fn(readBody, response)
 
         defer.returnValue(json.loads(body))
 
@@ -161,7 +161,7 @@ class SimpleHttpClient(object):
             })
         )
 
-        body = yield readBody(response)
+        body = yield preserve_context_over_fn(readBody, response)
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -204,7 +204,7 @@ class SimpleHttpClient(object):
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield readBody(response)
+        body = yield preserve_context_over_fn(readBody, response)
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -238,7 +238,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
         )
 
         try:
-            body = yield readBody(response)
+            body = yield preserve_context_over_fn(readBody, response)
             defer.returnValue(body)
         except PartialDownloadError as e:
             # twisted dislikes google's response, no content length.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1b90692731..ed47e701e7 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -127,7 +127,7 @@ class MatrixFederationHttpClient(object):
             ("", "", path_bytes, param_bytes, query_bytes, "",)
         )
 
-        txn_id = "%s-%s" % (method, self._next_id)
+        txn_id = "%s-O-%s" % (method, self._next_id)
         self._next_id = (self._next_id + 1) % (sys.maxint - 1)
 
         outbound_logger.info(
@@ -139,7 +139,9 @@ class MatrixFederationHttpClient(object):
         # (once we have reliable transactions in place)
         retries_left = 5
 
-        endpoint = self._getEndpoint(reactor, destination)
+        endpoint = preserve_context_over_fn(
+            self._getEndpoint, reactor, destination
+        )
 
         log_result = None
         try:
@@ -149,21 +151,25 @@ class MatrixFederationHttpClient(object):
                     producer = body_callback(method, url_bytes, headers_dict)
 
                 try:
-                    request_deferred = preserve_context_over_fn(
-                        self.agent.request,
-                        destination,
-                        endpoint,
-                        method,
-                        path_bytes,
-                        param_bytes,
-                        query_bytes,
-                        Headers(headers_dict),
-                        producer
-                    )
+                    def send_request():
+                        request_deferred = self.agent.request(
+                            destination,
+                            endpoint,
+                            method,
+                            path_bytes,
+                            param_bytes,
+                            query_bytes,
+                            Headers(headers_dict),
+                            producer
+                        )
+
+                        return self.clock.time_bound_deferred(
+                            request_deferred,
+                            time_out=timeout/1000. if timeout else 60,
+                        )
 
-                    response = yield self.clock.time_bound_deferred(
-                        request_deferred,
-                        time_out=timeout/1000. if timeout else 60,
+                    response = yield preserve_context_over_fn(
+                        send_request,
                     )
 
                     log_result = "%d %s" % (response.code, response.phrase,)
@@ -212,7 +218,7 @@ class MatrixFederationHttpClient(object):
         else:
             # :'(
             # Update transactions table?
-            body = yield readBody(response)
+            body = yield preserve_context_over_fn(readBody, response)
             raise HttpResponseException(
                 response.code, response.phrase, body
             )
@@ -292,10 +298,7 @@ class MatrixFederationHttpClient(object):
                     "Content-Type not application/json"
                 )
 
-        logger.debug("Getting resp body")
-        body = yield readBody(response)
-        logger.debug("Got resp body")
-
+        body = yield preserve_context_over_fn(readBody, response)
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
@@ -338,9 +341,7 @@ class MatrixFederationHttpClient(object):
                     "Content-Type not application/json"
                 )
 
-        logger.debug("Getting resp body")
-        body = yield readBody(response)
-        logger.debug("Got resp body")
+        body = yield preserve_context_over_fn(readBody, response)
 
         defer.returnValue(json.loads(body))
 
@@ -398,9 +399,7 @@ class MatrixFederationHttpClient(object):
                     "Content-Type not application/json"
                 )
 
-        logger.debug("Getting resp body")
-        body = yield readBody(response)
-        logger.debug("Got resp body")
+        body = yield preserve_context_over_fn(readBody, response)
 
         defer.returnValue(json.loads(body))
 
@@ -443,7 +442,10 @@ class MatrixFederationHttpClient(object):
         headers = dict(response.headers.getAllRawHeaders())
 
         try:
-            length = yield _readBodyToFile(response, output_stream, max_size)
+            length = yield preserve_context_over_fn(
+                _readBodyToFile,
+                response, output_stream, max_size
+            )
         except:
             logger.exception("Failed to download body")
             raise
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a92d518b43..7e6062c1b8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
                 )
 
 
+class _PreservingContextDeferred(defer.Deferred):
+    """A deferred that ensures that all callbacks and errbacks are called with
+    the given logging context.
+    """
+    def __init__(self, context):
+        self._log_context = context
+        defer.Deferred.__init__(self)
+
+    def addCallbacks(self, callback, errback=None,
+                     callbackArgs=None, callbackKeywords=None,
+                     errbackArgs=None, errbackKeywords=None):
+        callback = self._wrap_callback(callback)
+        errback = self._wrap_callback(errback)
+        return defer.Deferred.addCallbacks(
+            self, callback,
+            errback=errback,
+            callbackArgs=callbackArgs,
+            callbackKeywords=callbackKeywords,
+            errbackArgs=errbackArgs,
+            errbackKeywords=errbackKeywords,
+        )
+
+    def _wrap_callback(self, f):
+        def g(res, *args, **kwargs):
+            with PreserveLoggingContext():
+                LoggingContext.thread_local.current_context = self._log_context
+                res = f(res, *args, **kwargs)
+            return res
+        return g
+
+
 def preserve_context_over_fn(fn, *args, **kwargs):
     """Takes a function and invokes it with the given arguments, but removes
     and restores the current logging context while doing so.
@@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
     """Given a deferred wrap it such that any callbacks added later to it will
     be invoked with the current context.
     """
-    d = defer.Deferred()
-
     current_context = LoggingContext.current_context()
-
-    def cb(res):
-        with PreserveLoggingContext():
-            LoggingContext.thread_local.current_context = current_context
-            res = d.callback(res)
-        return res
-
-    def eb(failure):
-        with PreserveLoggingContext():
-            LoggingContext.thread_local.current_context = current_context
-            res = d.errback(failure)
-        return res
-
-    if deferred.called:
-        return deferred
-
-    deferred.addCallbacks(cb, eb)
+    d = _PreservingContextDeferred(current_context)
+    deferred.chainDeferred(d)
     return d