diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 7566d9eb33..133671e238 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -61,10 +61,8 @@ class Clock(object):
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
- current_context = LoggingContext.current_context()
-
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext(current_context):
+ with PreserveLoggingContext():
callback(*args, **kwargs)
with PreserveLoggingContext():
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 200edd404c..640fae3890 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -16,13 +16,16 @@
from twisted.internet import defer, reactor
-from .logcontext import preserve_context_over_deferred
+from .logcontext import PreserveLoggingContext
+@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
- reactor.callLater(seconds, d.callback, seconds)
- return preserve_context_over_deferred(d)
+ with PreserveLoggingContext():
+ reactor.callLater(seconds, d.callback, seconds)
+ res = yield d
+ defer.returnValue(res)
def run_on_reactor():
@@ -54,6 +57,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().callback(r)
except:
pass
@@ -63,6 +67,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
+ # TODO: Handle errors here.
self._observers.pop().errback(f)
except:
pass
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index e27917c63a..277854ccbc 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+ PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
from . import caches_by_name, DEBUG_CACHES, cache_counter
@@ -190,7 +193,7 @@ class CacheDescriptor(object):
defer.returnValue(cached_result)
observer.addCallback(check_result)
- return observer
+ return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
@@ -198,6 +201,7 @@ class CacheDescriptor(object):
sequence = self.cache.sequence
ret = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
obj, *args, **kwargs
)
@@ -211,7 +215,7 @@ class CacheDescriptor(object):
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
- return ret.observe()
+ return preserve_context_over_deferred(ret.observe())
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
@@ -299,6 +303,7 @@ class CacheListDescriptor(object):
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
+ preserve_context_over_fn,
self.function_to_call,
**args_to_call
)
@@ -308,7 +313,8 @@ class CacheListDescriptor(object):
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
- observer = ret_d.observe()
+ with PreserveLoggingContext():
+ observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
@@ -327,10 +333,10 @@ class CacheListDescriptor(object):
cached[arg] = res
- return defer.gatherResults(
+ return preserve_context_over_deferred(defer.gatherResults(
cached.values(),
consumeErrors=True,
- ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+ ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
obj.__dict__[self.orig.__name__] = wrapped
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
index b1e40417fd..d03678b8c8 100644
--- a/synapse/util/caches/snapshot_cache.py
+++ b/synapse/util/caches/snapshot_cache.py
@@ -87,7 +87,8 @@ class SnapshotCache(object):
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
+ return r
- result.observe().addBoth(shuffle_along)
+ result.addBoth(shuffle_along)
return result.observe()
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 4ebfebf701..8875813de4 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -15,9 +15,7 @@
from twisted.internet import defer
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import unwrapFirstError
@@ -97,6 +95,7 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
+ @defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
failure.getTracebackObject()))
if not self.suppress_failures:
return failure
+
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
for observer in self.observers
]
- d = defer.gatherResults(deferreds, consumeErrors=True)
+ res = yield defer.gatherResults(
+ deferreds, consumeErrors=True
+ ).addErrback(unwrapFirstError)
- d.addErrback(unwrapFirstError)
+ defer.returnValue(res)
- return preserve_context_over_deferred(d)
+ def __repr__(self):
+ return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index e701092cd8..9134e67908 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -48,7 +48,7 @@ class LoggingContext(object):
__slots__ = [
"parent_context", "name", "usage_start", "usage_end", "main_thread",
- "__dict__", "tag",
+ "__dict__", "tag", "alive",
]
thread_local = threading.local()
@@ -88,6 +88,7 @@ class LoggingContext(object):
self.usage_start = None
self.main_thread = threading.current_thread()
self.tag = ""
+ self.alive = True
def __str__(self):
return "%s@%x" % (self.name, id(self))
@@ -106,6 +107,7 @@ class LoggingContext(object):
The context that was previously active
"""
current = cls.current_context()
+
if current is not context:
current.stop()
cls.thread_local.current_context = context
@@ -117,6 +119,7 @@ class LoggingContext(object):
if self.parent_context is not None:
raise Exception("Attempt to enter logging context multiple times")
self.parent_context = self.set_current_context(self)
+ self.alive = True
return self
def __exit__(self, type, value, traceback):
@@ -136,6 +139,7 @@ class LoggingContext(object):
self
)
self.parent_context = None
+ self.alive = False
def __getattr__(self, name):
"""Delegate member lookup to parent context"""
@@ -213,7 +217,7 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor."""
- __slots__ = ["current_context", "new_context"]
+ __slots__ = ["current_context", "new_context", "has_parent"]
def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context
@@ -224,11 +228,26 @@ class PreserveLoggingContext(object):
self.new_context
)
+ if self.current_context:
+ self.has_parent = self.current_context.parent_context is not None
+ if not self.current_context.alive:
+ logger.warn(
+ "Entering dead context: %s",
+ self.current_context,
+ )
+
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
- LoggingContext.set_current_context(self.current_context)
+ context = LoggingContext.set_current_context(self.current_context)
+
+ if context != self.new_context:
+ logger.warn(
+ "Unexpected logging context: %s is not %s",
+ context, self.new_context,
+ )
+
if self.current_context is not LoggingContext.sentinel:
- if self.current_context.parent_context is None:
+ if not self.current_context.alive:
logger.warn(
"Restoring dead context: %s",
self.current_context,
@@ -289,3 +308,74 @@ def preserve_context_over_deferred(deferred):
d = _PreservingContextDeferred(current_context)
deferred.chainDeferred(d)
return d
+
+
+def preserve_fn(f):
+ """Ensures that function is called with correct context and that context is
+ restored after return. Useful for wrapping functions that return a deferred
+ which you don't yield on.
+ """
+ current = LoggingContext.current_context()
+
+ def g(*args, **kwargs):
+ with PreserveLoggingContext(current):
+ return f(*args, **kwargs)
+
+ return g
+
+
+# modules to ignore in `logcontext_tracer`
+_to_ignore = [
+ "synapse.util.logcontext",
+ "synapse.http.server",
+ "synapse.storage._base",
+ "synapse.util.async",
+]
+
+
+def logcontext_tracer(frame, event, arg):
+ """A tracer that logs whenever a logcontext "unexpectedly" changes within
+ a function. Probably inaccurate.
+
+ Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
+ """
+ if event == 'call':
+ name = frame.f_globals["__name__"]
+ if name.startswith("synapse"):
+ if name == "synapse.util.logcontext":
+ if frame.f_code.co_name in ["__enter__", "__exit__"]:
+ tracer = frame.f_back.f_trace
+ if tracer:
+ tracer.just_changed = True
+
+ tracer = frame.f_trace
+ if tracer:
+ return tracer
+
+ if not any(name.startswith(ig) for ig in _to_ignore):
+ return LineTracer()
+
+
+class LineTracer(object):
+ __slots__ = ["context", "just_changed"]
+
+ def __init__(self):
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+
+ def __call__(self, frame, event, arg):
+ if event in 'line':
+ if self.just_changed:
+ self.context = LoggingContext.current_context()
+ self.just_changed = False
+ else:
+ c = LoggingContext.current_context()
+ if c != self.context:
+ logger.info(
+ "Context changed! %s -> %s, %s, %s",
+ self.context, c,
+ frame.f_code.co_filename, frame.f_lineno
+ )
+ self.context = c
+
+ return self
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index c37a157787..3a83828d25 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -168,3 +168,38 @@ def trace_function(f):
wrapped.__name__ = func_name
return wrapped
+
+
+def get_previous_frames():
+ s = inspect.currentframe().f_back.f_back
+ to_return = []
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ to_return.append("{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ ))
+
+ s = s.f_back
+
+ return ", ". join(to_return)
+
+
+def get_previous_frame(ignore=[]):
+ s = inspect.currentframe().f_back.f_back
+
+ while s:
+ if s.f_globals["__name__"].startswith("synapse"):
+ if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
+ filename, lineno, function, _, _ = inspect.getframeinfo(s)
+ args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+ return "{{ %s:%d %s - Args: %s }}" % (
+ filename, lineno, function, args_string
+ )
+
+ s = s.f_back
+
+ return None
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index daf6087fe0..ca48007218 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -68,16 +68,18 @@ class Measure(object):
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
- if not context:
- return
if context != self.start_context:
logger.warn(
- "Context have unexpectedly changed %r, %r",
- context, self.start_context
+ "Context have unexpectedly changed from '%s' to '%s'. (%r)",
+ context, self.start_context, self.name
)
return
+ if not context:
+ logger.warn("Expected context. (%r)", self.name)
+ return
+
ru_utime, ru_stime = context.get_resource_usage()
block_ru_utime.inc_by(ru_utime, self.name)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index ea321bc6a9..4076eed269 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
+from synapse.util.logcontext import preserve_fn
import collections
import contextlib
@@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = sleep(self.sleep_msec / 1000.0)
+ ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
|