diff --git a/changelog.d/6127.misc b/changelog.d/6127.misc
new file mode 100644
index 0000000000..7bfbcfc252
--- /dev/null
+++ b/changelog.d/6127.misc
@@ -0,0 +1 @@
+Add env var to turn on tracking of log context changes.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 2d52d26af5..56df3f5ac6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,6 +17,7 @@
""" This is a reference implementation of a Matrix home server.
"""
+import os
import sys
# Check that we're not running on an unsupported Python version.
@@ -36,3 +37,10 @@ except ImportError:
pass
__version__ = "1.4.0"
+
+if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
+ # We import here so that we don't have to install a bunch of deps when
+ # running the packaging tox test.
+ from synapse.util.patch_inline_callbacks import do_patch
+
+ do_patch()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abe16334ec..06cc14fcd1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -30,7 +30,7 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
@@ -550,8 +550,9 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ result = yield make_deferred_yieldable(
+ self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ )
return result
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ global _already_patched
+
+ orig_inline_callbacks = defer.inlineCallbacks
+ if _already_patched:
+ return
+
+ def new_inline_callbacks(f):
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+ changes = [] # type: List[str]
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "%s changed context from %s to %s on exception" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "Completed %s changed context from %s to %s" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (f, LoggingContext.current_context(), start_context)
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
+ _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+ """Wraps a generator that is about to be passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+
+ It's perfectly valid for log contexts to change within a function, e.g. due
+ to new Measure blocks, so such changes are added to the given `changes`
+ list instead of triggering an exception.
+
+ Args:
+ f: generator function to wrap
+ changes: A list of strings detailing how the contexts
+ changed within a function.
+
+ Returns:
+ function
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ @functools.wraps(f)
+ def check_yield_points_inner(*args, **kwargs):
+ gen = f(*args, **kwargs)
+
+ last_yield_line_no = gen.gi_frame.f_lineno
+ result = None # type: Any
+ while True:
+ expected_context = LoggingContext.current_context()
+
+ try:
+ isFailure = isinstance(result, Failure)
+ if isFailure:
+ d = result.throwExceptionIntoGenerator(gen)
+ else:
+ d = gen.send(result)
+ except (StopIteration, defer._DefGen_Return) as e:
+ if LoggingContext.current_context() != expected_context:
+ # This happens when the context is lost sometime *after* the
+ # final yield and returning. E.g. we forgot to yield on a
+ # function that returns a deferred.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
+ % (
+ f.__qualname__,
+ expected_context,
+ LoggingContext.current_context(),
+ f.__code__.co_filename,
+ last_yield_line_no,
+ )
+ )
+ changes.append(err)
+ return getattr(e, "value", None)
+
+ frame = gen.gi_frame
+
+ if isinstance(d, defer.Deferred) and not d.called:
+ # This happens if we yield on a deferred that doesn't follow
+ # the log context rules without wrapping in a `make_deferred_yieldable`.
+ # We raise here as this should never happen.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ err = (
+ "%s yielded with context %s rather than sentinel,"
+ " yielded on line %d in %s"
+ % (
+ frame.f_code.co_name,
+ LoggingContext.current_context(),
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ raise Exception(err)
+
+ try:
+ result = yield d
+ except Exception as e:
+ result = Failure(e)
+
+ if LoggingContext.current_context() != expected_context:
+
+ # This happens because the context is lost sometime *after* the
+ # previous yield and *after* the current yield. E.g. the
+ # deferred we waited on didn't follow the rules, or we forgot to
+ # yield on a function between the two yield points.
+ #
+ # We don't raise here as its perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "%s changed context from %s to %s, happened between lines %d and %d in %s"
+ % (
+ frame.f_code.co_name,
+ expected_context,
+ LoggingContext.current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ changes.append(err)
+
+ last_yield_line_no = frame.f_lineno
+
+ return check_yield_points_inner
diff --git a/tests/__init__.py b/tests/__init__.py
index f7fc502f01..ed805db1c2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,9 +16,9 @@
from twisted.trial import util
-import tests.patch_inline_callbacks
+from synapse.util.patch_inline_callbacks import do_patch
# attempt to do the patch before we load any synapse code
-tests.patch_inline_callbacks.do_patch()
+do_patch()
util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
deleted file mode 100644
index 220884311c..0000000000
--- a/tests/patch_inline_callbacks.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import functools
-import sys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-
-
-def do_patch():
- """
- Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
- """
-
- from synapse.logging.context import LoggingContext
-
- orig_inline_callbacks = defer.inlineCallbacks
-
- def new_inline_callbacks(f):
-
- orig = orig_inline_callbacks(f)
-
- @functools.wraps(f)
- def wrapped(*args, **kwargs):
- start_context = LoggingContext.current_context()
-
- try:
- res = orig(*args, **kwargs)
- except Exception:
- if LoggingContext.current_context() != start_context:
- err = "%s changed context from %s to %s on exception" % (
- f,
- start_context,
- LoggingContext.current_context(),
- )
- print(err, file=sys.stderr)
- raise Exception(err)
- raise
-
- if not isinstance(res, Deferred) or res.called:
- if LoggingContext.current_context() != start_context:
- err = "%s changed context from %s to %s" % (
- f,
- start_context,
- LoggingContext.current_context(),
- )
- # print the error to stderr because otherwise all we
- # see in travis-ci is the 500 error
- print(err, file=sys.stderr)
- raise Exception(err)
- return res
-
- if LoggingContext.current_context() != LoggingContext.sentinel:
- err = (
- "%s returned incomplete deferred in non-sentinel context "
- "%s (start was %s)"
- ) % (f, LoggingContext.current_context(), start_context)
- print(err, file=sys.stderr)
- raise Exception(err)
-
- def check_ctx(r):
- if LoggingContext.current_context() != start_context:
- err = "%s completion of %s changed context from %s to %s" % (
- "Failure" if isinstance(r, Failure) else "Success",
- f,
- start_context,
- LoggingContext.current_context(),
- )
- print(err, file=sys.stderr)
- raise Exception(err)
- return r
-
- res.addBoth(check_ctx)
- return res
-
- return wrapped
-
- defer.inlineCallbacks = new_inline_callbacks
|