summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-10-10 13:47:50 +0100
committerGitHub <noreply@github.com>2019-10-10 13:47:50 +0100
commit933034e2feaae56640e4bbb7f58c85e4bbf720c9 (patch)
treea23b9fa75cfb4c2d754c785a2ef0ab1660d2e6dc
parentAdd domain validation when creating room with list of invitees (#6121) (diff)
parentFixup comments (diff)
downloadsynapse-933034e2feaae56640e4bbb7f58c85e4bbf720c9.tar.xz
Merge pull request #6127 from matrix-org/erikj/patch_inner
Add more log context checks when patching inlineCallbacks
-rw-r--r--changelog.d/6127.misc1
-rw-r--r--synapse/__init__.py8
-rw-r--r--synapse/storage/_base.py7
-rw-r--r--synapse/util/patch_inline_callbacks.py219
-rw-r--r--tests/__init__.py4
-rw-r--r--tests/patch_inline_callbacks.py94
6 files changed, 234 insertions, 99 deletions
diff --git a/changelog.d/6127.misc b/changelog.d/6127.misc
new file mode 100644
index 0000000000..7bfbcfc252
--- /dev/null
+++ b/changelog.d/6127.misc
@@ -0,0 +1 @@
+Add env var to turn on tracking of log context changes.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 2d52d26af5..56df3f5ac6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,6 +17,7 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
+import os
 import sys
 
 # Check that we're not running on an unsupported Python version.
@@ -36,3 +37,10 @@ except ImportError:
     pass
 
 __version__ = "1.4.0"
+
+if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
+    # We import here so that we don't have to install a bunch of deps when
+    # running the packaging tox test.
+    from synapse.util.patch_inline_callbacks import do_patch
+
+    do_patch()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index abe16334ec..06cc14fcd1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -30,7 +30,7 @@ from prometheus_client import Histogram
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id
@@ -550,8 +550,9 @@ class SQLBaseStore(object):
 
                 return func(conn, *args, **kwargs)
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+        result = yield make_deferred_yieldable(
+            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+        )
 
         return result
 
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+    """
+    Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+    """
+
+    from synapse.logging.context import LoggingContext
+
+    global _already_patched
+
+    orig_inline_callbacks = defer.inlineCallbacks
+    if _already_patched:
+        return
+
+    def new_inline_callbacks(f):
+        @functools.wraps(f)
+        def wrapped(*args, **kwargs):
+            start_context = LoggingContext.current_context()
+            changes = []  # type: List[str]
+            orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+            try:
+                res = orig(*args, **kwargs)
+            except Exception:
+                if LoggingContext.current_context() != start_context:
+                    for err in changes:
+                        print(err, file=sys.stderr)
+
+                    err = "%s changed context from %s to %s on exception" % (
+                        f,
+                        start_context,
+                        LoggingContext.current_context(),
+                    )
+                    print(err, file=sys.stderr)
+                    raise Exception(err)
+                raise
+
+            if not isinstance(res, Deferred) or res.called:
+                if LoggingContext.current_context() != start_context:
+                    for err in changes:
+                        print(err, file=sys.stderr)
+
+                    err = "Completed %s changed context from %s to %s" % (
+                        f,
+                        start_context,
+                        LoggingContext.current_context(),
+                    )
+                    # print the error to stderr because otherwise all we
+                    # see in travis-ci is the 500 error
+                    print(err, file=sys.stderr)
+                    raise Exception(err)
+                return res
+
+            if LoggingContext.current_context() != LoggingContext.sentinel:
+                err = (
+                    "%s returned incomplete deferred in non-sentinel context "
+                    "%s (start was %s)"
+                ) % (f, LoggingContext.current_context(), start_context)
+                print(err, file=sys.stderr)
+                raise Exception(err)
+
+            def check_ctx(r):
+                if LoggingContext.current_context() != start_context:
+                    for err in changes:
+                        print(err, file=sys.stderr)
+                    err = "%s completion of %s changed context from %s to %s" % (
+                        "Failure" if isinstance(r, Failure) else "Success",
+                        f,
+                        start_context,
+                        LoggingContext.current_context(),
+                    )
+                    print(err, file=sys.stderr)
+                    raise Exception(err)
+                return r
+
+            res.addBoth(check_ctx)
+            return res
+
+        return wrapped
+
+    defer.inlineCallbacks = new_inline_callbacks
+    _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+    """Wraps a generator that is about to be passed to defer.inlineCallbacks
+    checking that after every yield the log contexts are correct.
+
+    It's perfectly valid for log contexts to change within a function, e.g. due
+    to new Measure blocks, so such changes are added to the given `changes`
+    list instead of triggering an exception.
+
+    Args:
+        f: generator function to wrap
+        changes: A list of strings detailing how the contexts
+            changed within a function.
+
+    Returns:
+        function
+    """
+
+    from synapse.logging.context import LoggingContext
+
+    @functools.wraps(f)
+    def check_yield_points_inner(*args, **kwargs):
+        gen = f(*args, **kwargs)
+
+        last_yield_line_no = gen.gi_frame.f_lineno
+        result = None  # type: Any
+        while True:
+            expected_context = LoggingContext.current_context()
+
+            try:
+                isFailure = isinstance(result, Failure)
+                if isFailure:
+                    d = result.throwExceptionIntoGenerator(gen)
+                else:
+                    d = gen.send(result)
+            except (StopIteration, defer._DefGen_Return) as e:
+                if LoggingContext.current_context() != expected_context:
+                    # This happens when the context is lost sometime *after* the
+                    # final yield and returning. E.g. we forgot to yield on a
+                    # function that returns a deferred.
+                    #
+                    # We don't raise here as it's perfectly valid for contexts to
+                    # change in a function, as long as it sets the correct context
+                    # on resolving (which is checked separately).
+                    err = (
+                        "Function %r returned and changed context from %s to %s,"
+                        " in %s between %d and end of func"
+                        % (
+                            f.__qualname__,
+                            expected_context,
+                            LoggingContext.current_context(),
+                            f.__code__.co_filename,
+                            last_yield_line_no,
+                        )
+                    )
+                    changes.append(err)
+                return getattr(e, "value", None)
+
+            frame = gen.gi_frame
+
+            if isinstance(d, defer.Deferred) and not d.called:
+                # This happens if we yield on a deferred that doesn't follow
+                # the log context rules without wrapping in a `make_deferred_yieldable`.
+                # We raise here as this should never happen.
+                if LoggingContext.current_context() is not LoggingContext.sentinel:
+                    err = (
+                        "%s yielded with context %s rather than sentinel,"
+                        " yielded on line %d in %s"
+                        % (
+                            frame.f_code.co_name,
+                            LoggingContext.current_context(),
+                            frame.f_lineno,
+                            frame.f_code.co_filename,
+                        )
+                    )
+                    raise Exception(err)
+
+            try:
+                result = yield d
+            except Exception as e:
+                result = Failure(e)
+
+            if LoggingContext.current_context() != expected_context:
+
+                # This happens because the context is lost sometime *after* the
+                # previous yield and *after* the current yield. E.g. the
+                # deferred we waited on didn't follow the rules, or we forgot to
+                # yield on a function between the two yield points.
+                #
+                # We don't raise here as its perfectly valid for contexts to
+                # change in a function, as long as it sets the correct context
+                # on resolving (which is checked separately).
+                err = (
+                    "%s changed context from %s to %s, happened between lines %d and %d in %s"
+                    % (
+                        frame.f_code.co_name,
+                        expected_context,
+                        LoggingContext.current_context(),
+                        last_yield_line_no,
+                        frame.f_lineno,
+                        frame.f_code.co_filename,
+                    )
+                )
+                changes.append(err)
+
+            last_yield_line_no = frame.f_lineno
+
+    return check_yield_points_inner
diff --git a/tests/__init__.py b/tests/__init__.py
index f7fc502f01..ed805db1c2 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -16,9 +16,9 @@
 
 from twisted.trial import util
 
-import tests.patch_inline_callbacks
+from synapse.util.patch_inline_callbacks import do_patch
 
 # attempt to do the patch before we load any synapse code
-tests.patch_inline_callbacks.do_patch()
+do_patch()
 
 util.DEFAULT_TIMEOUT_DURATION = 20
diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py
deleted file mode 100644
index 220884311c..0000000000
--- a/tests/patch_inline_callbacks.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import print_function
-
-import functools
-import sys
-
-from twisted.internet import defer
-from twisted.internet.defer import Deferred
-from twisted.python.failure import Failure
-
-
-def do_patch():
-    """
-    Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
-    """
-
-    from synapse.logging.context import LoggingContext
-
-    orig_inline_callbacks = defer.inlineCallbacks
-
-    def new_inline_callbacks(f):
-
-        orig = orig_inline_callbacks(f)
-
-        @functools.wraps(f)
-        def wrapped(*args, **kwargs):
-            start_context = LoggingContext.current_context()
-
-            try:
-                res = orig(*args, **kwargs)
-            except Exception:
-                if LoggingContext.current_context() != start_context:
-                    err = "%s changed context from %s to %s on exception" % (
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                raise
-
-            if not isinstance(res, Deferred) or res.called:
-                if LoggingContext.current_context() != start_context:
-                    err = "%s changed context from %s to %s" % (
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    # print the error to stderr because otherwise all we
-                    # see in travis-ci is the 500 error
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                return res
-
-            if LoggingContext.current_context() != LoggingContext.sentinel:
-                err = (
-                    "%s returned incomplete deferred in non-sentinel context "
-                    "%s (start was %s)"
-                ) % (f, LoggingContext.current_context(), start_context)
-                print(err, file=sys.stderr)
-                raise Exception(err)
-
-            def check_ctx(r):
-                if LoggingContext.current_context() != start_context:
-                    err = "%s completion of %s changed context from %s to %s" % (
-                        "Failure" if isinstance(r, Failure) else "Success",
-                        f,
-                        start_context,
-                        LoggingContext.current_context(),
-                    )
-                    print(err, file=sys.stderr)
-                    raise Exception(err)
-                return r
-
-            res.addBoth(check_ctx)
-            return res
-
-        return wrapped
-
-    defer.inlineCallbacks = new_inline_callbacks