diff --git a/synapse/http/client.py b/synapse/http/client.py
index 9091ae2d38..49737d55da 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -109,7 +109,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -128,7 +128,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -161,7 +161,7 @@ class SimpleHttpClient(object):
})
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -204,7 +204,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str))
)
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -238,7 +238,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1b90692731..ed47e701e7 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -127,7 +127,7 @@ class MatrixFederationHttpClient(object):
("", "", path_bytes, param_bytes, query_bytes, "",)
)
- txn_id = "%s-%s" % (method, self._next_id)
+ txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
outbound_logger.info(
@@ -139,7 +139,9 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place)
retries_left = 5
- endpoint = self._getEndpoint(reactor, destination)
+ endpoint = preserve_context_over_fn(
+ self._getEndpoint, reactor, destination
+ )
log_result = None
try:
@@ -149,21 +151,25 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, url_bytes, headers_dict)
try:
- request_deferred = preserve_context_over_fn(
- self.agent.request,
- destination,
- endpoint,
- method,
- path_bytes,
- param_bytes,
- query_bytes,
- Headers(headers_dict),
- producer
- )
+ def send_request():
+ request_deferred = self.agent.request(
+ destination,
+ endpoint,
+ method,
+ path_bytes,
+ param_bytes,
+ query_bytes,
+ Headers(headers_dict),
+ producer
+ )
+
+ return self.clock.time_bound_deferred(
+ request_deferred,
+ time_out=timeout/1000. if timeout else 60,
+ )
- response = yield self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout/1000. if timeout else 60,
+ response = yield preserve_context_over_fn(
+ send_request,
)
log_result = "%d %s" % (response.code, response.phrase,)
@@ -212,7 +218,7 @@ class MatrixFederationHttpClient(object):
else:
# :'(
# Update transactions table?
- body = yield readBody(response)
+ body = yield preserve_context_over_fn(readBody, response)
raise HttpResponseException(
response.code, response.phrase, body
)
@@ -292,10 +298,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
- logger.debug("Getting resp body")
- body = yield readBody(response)
- logger.debug("Got resp body")
-
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
@@ -338,9 +341,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
- logger.debug("Getting resp body")
- body = yield readBody(response)
- logger.debug("Got resp body")
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -398,9 +399,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
- logger.debug("Getting resp body")
- body = yield readBody(response)
- logger.debug("Got resp body")
+ body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body))
@@ -443,7 +442,10 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders())
try:
- length = yield _readBodyToFile(response, output_stream, max_size)
+ length = yield preserve_context_over_fn(
+ _readBodyToFile,
+ response, output_stream, max_size
+ )
except:
logger.exception("Failed to download body")
raise
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index a92d518b43..7e6062c1b8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
)
+class _PreservingContextDeferred(defer.Deferred):
+ """A deferred that ensures that all callbacks and errbacks are called with
+ the given logging context.
+ """
+ def __init__(self, context):
+ self._log_context = context
+ defer.Deferred.__init__(self)
+
+ def addCallbacks(self, callback, errback=None,
+ callbackArgs=None, callbackKeywords=None,
+ errbackArgs=None, errbackKeywords=None):
+ callback = self._wrap_callback(callback)
+ errback = self._wrap_callback(errback)
+ return defer.Deferred.addCallbacks(
+ self, callback,
+ errback=errback,
+ callbackArgs=callbackArgs,
+ callbackKeywords=callbackKeywords,
+ errbackArgs=errbackArgs,
+ errbackKeywords=errbackKeywords,
+ )
+
+ def _wrap_callback(self, f):
+ def g(res, *args, **kwargs):
+ with PreserveLoggingContext():
+ LoggingContext.thread_local.current_context = self._log_context
+ res = f(res, *args, **kwargs)
+ return res
+ return g
+
+
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
@@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
"""
- d = defer.Deferred()
-
current_context = LoggingContext.current_context()
-
- def cb(res):
- with PreserveLoggingContext():
- LoggingContext.thread_local.current_context = current_context
- res = d.callback(res)
- return res
-
- def eb(failure):
- with PreserveLoggingContext():
- LoggingContext.thread_local.current_context = current_context
- res = d.errback(failure)
- return res
-
- if deferred.called:
- return deferred
-
- deferred.addCallbacks(cb, eb)
+ d = _PreservingContextDeferred(current_context)
+ deferred.chainDeferred(d)
return d
|