summary refs log tree commit diff
path: root/synapse/util/patch_inline_callbacks.py
blob: 812dc888350c550e3d197af26f7ce4d962508bf2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import functools
import sys
from typing import Any, Callable, List

from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure

# Set to True by do_patch() once defer.inlineCallbacks has been replaced, so
# that further calls to do_patch() return without patching a second time.
_already_patched = False


def do_patch():
    """
    Replace defer.inlineCallbacks with a wrapper that verifies the state of the
    logcontext whenever a decorated function exits.

    The wrapper raises an Exception (after printing the details to stderr) if:

    * the function raised, or finished synchronously, and left a different
      logcontext active from the one it was called with;
    * the function returned an incomplete Deferred while a non-sentinel
      logcontext was active; or
    * the returned Deferred later fires while a logcontext other than the
      starting one is active.

    Only the first call installs the patch; later calls return immediately.
    """

    from synapse.logging.context import LoggingContext

    global _already_patched

    if _already_patched:
        return

    # Keep hold of the real decorator so the replacement can delegate to it.
    real_inline_callbacks = defer.inlineCallbacks

    def _active_context():
        # Looked up afresh on every call so we always see the live context.
        return LoggingContext.current_context()

    def _dump_changes(recorded):
        # Emit every context change noted by _check_yield_points.
        for line in recorded:
            print(line, file=sys.stderr)

    def _fail(message):
        # print the error to stderr because otherwise all we
        # see in travis-ci is the 500 error
        print(message, file=sys.stderr)
        raise Exception(message)

    def new_inline_callbacks(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            start_context = _active_context()
            changes = []  # type: List[str]
            runner = real_inline_callbacks(_check_yield_points(f, changes))

            try:
                res = runner(*args, **kwargs)
            except Exception:
                if _active_context() != start_context:
                    _dump_changes(changes)
                    _fail(
                        "%s changed context from %s to %s on exception"
                        % (f, start_context, _active_context())
                    )
                # Context is intact: let the original exception propagate.
                raise

            finished = not isinstance(res, Deferred) or res.called
            if finished:
                # Completed synchronously, so the context must be unchanged.
                if _active_context() != start_context:
                    _dump_changes(changes)
                    _fail(
                        "Completed %s changed context from %s to %s"
                        % (f, start_context, _active_context())
                    )
                return res

            # Still pending: the function must have handed back control with
            # the sentinel context active.
            if _active_context() != LoggingContext.sentinel:
                _fail(
                    (
                        "%s returned incomplete deferred in non-sentinel context "
                        "%s (start was %s)"
                    )
                    % (f, _active_context(), start_context)
                )

            def check_ctx(r):
                # Runs when the deferred fires: the starting context should be
                # active again by then.
                if _active_context() != start_context:
                    _dump_changes(changes)
                    _fail(
                        "%s completion of %s changed context from %s to %s"
                        % (
                            "Failure" if isinstance(r, Failure) else "Success",
                            f,
                            start_context,
                            _active_context(),
                        )
                    )
                return r

            res.addBoth(check_ctx)
            return res

        return wrapped

    defer.inlineCallbacks = new_inline_callbacks
    _already_patched = True


def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
    """Wraps a generator that is about to be passed to defer.inlineCallbacks
    checking that after every yield the log contexts are correct.

    It's perfectly valid for log contexts to change within a function, e.g. due
    to new Measure blocks, so such changes are added to the given `changes`
    list instead of triggering an exception.

    The one case that does raise is when the wrapped generator yields an
    incomplete Deferred while a non-sentinel log context is active.

    Args:
        f: generator function to wrap
        changes: A list of strings detailing how the contexts
            changed within a function. It is appended to in place.

    Returns:
        A generator function that drives `f`, re-yielding everything `f`
        yields and returning whatever `f` returns.

    Raises:
        Exception: (from the returned generator) if `f` yields an incomplete
            Deferred while the current log context is not the sentinel.
    """

    from synapse.logging.context import LoggingContext

    @functools.wraps(f)
    def check_yield_points_inner(*args, **kwargs):
        gen = f(*args, **kwargs)

        # Line the generator frame is at before it has been started; used as
        # the "from" line in the first change report.
        last_yield_line_no = gen.gi_frame.f_lineno
        # Value (or Failure) to feed into the wrapped generator on the next
        # iteration; None starts the generator.
        result: Any = None
        while True:
            # The context we expect to still be active once the wrapped
            # generator has run to its next yield and we have been resumed.
            expected_context = LoggingContext.current_context()

            try:
                # Resume the wrapped generator, either by throwing the stored
                # failure into it or by sending it the previous result.
                isFailure = isinstance(result, Failure)
                if isFailure:
                    d = result.throwExceptionIntoGenerator(gen)
                else:
                    d = gen.send(result)
            except (StopIteration, defer._DefGen_Return) as e:
                if LoggingContext.current_context() != expected_context:
                    # This happens when the context is lost sometime *after* the
                    # final yield and returning. E.g. we forgot to yield on a
                    # function that returns a deferred.
                    #
                    # We don't raise here as it's perfectly valid for contexts to
                    # change in a function, as long as it sets the correct context
                    # on resolving (which is checked separately).
                    err = (
                        "Function %r returned and changed context from %s to %s,"
                        " in %s between %d and end of func"
                        % (
                            f.__qualname__,
                            expected_context,
                            LoggingContext.current_context(),
                            f.__code__.co_filename,
                            last_yield_line_no,
                        )
                    )
                    changes.append(err)
                # Pass the wrapped generator's return value (the exception's
                # `value` attribute, or None if it has none) back out.
                return getattr(e, "value", None)

            # The wrapped generator is now suspended at a yield; its frame tells
            # us which line that is.
            frame = gen.gi_frame

            if isinstance(d, defer.Deferred) and not d.called:
                # This happens if we yield on a deferred that doesn't follow
                # the log context rules without wrapping in a `make_deferred_yieldable`.
                # We raise here as this should never happen.
                if LoggingContext.current_context() is not LoggingContext.sentinel:
                    err = (
                        "%s yielded with context %s rather than sentinel,"
                        " yielded on line %d in %s"
                        % (
                            frame.f_code.co_name,
                            LoggingContext.current_context(),
                            frame.f_lineno,
                            frame.f_code.co_filename,
                        )
                    )
                    raise Exception(err)

            try:
                # Re-yield to whoever is driving us and wait to be resumed with
                # the outcome.
                result = yield d
            except Exception as e:
                # An exception was thrown into us: wrap it so that the next
                # iteration throws it into the wrapped generator.
                result = Failure(e)

            if LoggingContext.current_context() != expected_context:

                # This happens because the context was lost somewhere between
                # being resumed from the previous yield and being resumed from
                # the current one. E.g. the deferred we waited on didn't follow
                # the rules, or we forgot to yield on a function between the
                # two yield points.
                #
                # We don't raise here as it's perfectly valid for contexts to
                # change in a function, as long as it sets the correct context
                # on resolving (which is checked separately).
                err = (
                    "%s changed context from %s to %s, happened between lines %d and %d in %s"
                    % (
                        frame.f_code.co_name,
                        expected_context,
                        LoggingContext.current_context(),
                        last_yield_line_no,
                        frame.f_lineno,
                        frame.f_code.co_filename,
                    )
                )
                changes.append(err)

            # Remember where this yield was, for the next change report.
            last_yield_line_no = frame.f_lineno

    return check_yield_points_inner