summary refs log tree commit diff
path: root/synapse/util/patch_inline_callbacks.py
blob: 46dad32156b139f3edffb2e21be16d308cda2143 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#

import functools
import sys
import types
from typing import Any, Callable, Generator, List, TypeVar, cast

from typing_extensions import ParamSpec

from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure

# Tracks if we've already patched inlineCallbacks. do_patch() checks this flag
# and returns early, so calling it more than once is a no-op.
_already_patched = False


# Type variables describing the decorated generator functions: `T` is the
# generator's return type (and thus the result type of the Deferred that
# inlineCallbacks produces) and `P` captures its parameters, so the wrappers
# keep the original call signature.
T = TypeVar("T")
P = ParamSpec("P")


def do_patch() -> None:
    """
    Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit.

    Once patched, each call to a function decorated with ``defer.inlineCallbacks``
    raises an ``Exception`` if:

    * it raises synchronously, or returns an already-fired (or non-Deferred)
      result, in a different logcontext from the one it was called in;
    * it returns a Deferred that has not fired yet while a non-sentinel
      logcontext is current;
    * its Deferred later fires (success or failure) in a different logcontext
      from the one it was called in.

    Before raising, the intermediate context changes recorded by
    ``_check_yield_points`` (except in the incomplete-Deferred case) and the
    error message are printed to stderr. Calling this function again after a
    successful patch does nothing.
    """

    from synapse.logging.context import current_context

    global _already_patched

    if _already_patched:
        return

    # Keep a handle on the real decorator; the patched one delegates to it.
    real_inline_callbacks = defer.inlineCallbacks

    def new_inline_callbacks(
        f: Callable[P, Generator["Deferred[object]", object, T]]
    ) -> Callable[P, "Deferred[T]"]:
        @functools.wraps(f)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> "Deferred[T]":
            entry_ctx = current_context()
            # Filled in by _check_yield_points with any context changes seen
            # between yields; only dumped if one of the checks below fails.
            changes: List[str] = []
            decorated: Callable[P, "Deferred[T]"] = real_inline_callbacks(
                _check_yield_points(f, changes)
            )

            def report(msg: str, *, include_changes: bool = True) -> Exception:
                # Print to stderr because otherwise all we see in CI is the
                # resulting 500 error; the caller raises the returned object.
                if include_changes:
                    for change in changes:
                        print(change, file=sys.stderr)
                print(msg, file=sys.stderr)
                return Exception(msg)

            try:
                res: "Deferred[T]" = decorated(*args, **kwargs)
            except Exception:
                if current_context() != entry_ctx:
                    raise report(
                        "%s changed context from %s to %s on exception"
                        % (f, entry_ctx, current_context())
                    )
                raise

            still_pending = isinstance(res, Deferred) and not res.called
            if not still_pending:
                # Already finished (or not a Deferred at all): check right away.
                if current_context() != entry_ctx:
                    raise report(
                        "Completed %s changed context from %s to %s"
                        % (f, entry_ctx, current_context())
                    )
                return res

            # A truthy current context is treated as "non-sentinel": returning
            # an unfired Deferred from such a context is an error.
            if current_context():
                raise report(
                    (
                        "%s returned incomplete deferred in non-sentinel context "
                        "%s (start was %s)"
                    )
                    % (f, current_context(), entry_ctx),
                    include_changes=False,
                )

            def check_ctx(r: T) -> T:
                # Runs when the Deferred fires, for both results and Failures.
                if current_context() != entry_ctx:
                    outcome = "Failure" if isinstance(r, Failure) else "Success"
                    raise report(
                        "%s completion of %s changed context from %s to %s"
                        % (outcome, f, entry_ctx, current_context())
                    )
                return r

            res.addBoth(check_ctx)
            return res

        return wrapped

    defer.inlineCallbacks = new_inline_callbacks
    _already_patched = True


def _check_yield_points(
    f: Callable[P, Generator["Deferred[object]", object, T]],
    changes: List[str],
) -> Callable[P, Generator["Deferred[object]", object, T]]:
    """Wraps a generator that is about to be passed to defer.inlineCallbacks
    checking that after every yield the log contexts are correct.

    It's perfectly valid for log contexts to change within a function, e.g. due
    to new Measure blocks, so such changes are added to the given `changes`
    list instead of triggering an exception.

    Args:
        f: generator function to wrap
        changes: A list of strings detailing how the contexts
            changed within a function. New entries are appended in place.

    Returns:
        A generator function with the same signature as `f`. It drives the
        generator produced by `f`: each value `f` yields is yielded up to the
        enclosing inlineCallbacks, each result (or Failure) it receives back is
        passed down into `f`'s generator, and it finally returns whatever
        `f`'s generator returned.

    The returned generator raises:
        TypeError: if `f` does not produce a generator.
        Exception: if `f` yields a Deferred that has not fired yet while the
            current log context is not the sentinel context.
    """

    from synapse.logging.context import current_context

    @functools.wraps(f)
    def check_yield_points_inner(
        *args: P.args, **kwargs: P.kwargs
    ) -> Generator["Deferred[object]", object, T]:
        gen = f(*args, **kwargs)

        # This wrapper is itself a generator function, so any "did the
        # decorated function produce a generator?" check made by
        # inlineCallbacks is satisfied by the wrapper rather than by `f`.
        # Check `f`'s result here so misuse gives a clear TypeError instead of
        # an AttributeError on `gen.gi_frame` below.
        if not isinstance(gen, types.GeneratorType):
            raise TypeError(
                "inlineCallbacks requires %r to produce a generator; instead got %r"
                % (f, gen)
            )

        last_yield_line_no = gen.gi_frame.f_lineno
        # The next value (or Failure) to feed into the wrapped generator;
        # None starts it.
        result: Any = None
        while True:
            expected_context = current_context()

            try:
                if isinstance(result, Failure):
                    d = result.throwExceptionIntoGenerator(gen)
                else:
                    d = gen.send(result)
            except (StopIteration, defer._DefGen_Return) as e:
                if current_context() != expected_context:
                    # This happens when the context is lost sometime *after* the
                    # final yield and returning. E.g. we forgot to yield on a
                    # function that returns a deferred.
                    #
                    # We don't raise here as it's perfectly valid for contexts to
                    # change in a function, as long as it sets the correct context
                    # on resolving (which is checked separately).
                    err = (
                        "Function %r returned and changed context from %s to %s,"
                        " in %s between %d and end of func"
                        % (
                            f.__qualname__,
                            expected_context,
                            current_context(),
                            # Take the filename from the generator's own code
                            # object so it matches `last_yield_line_no` (which
                            # also comes from the generator), even when `f` is
                            # a wrapper around the real generator function.
                            gen.gi_code.co_filename,
                            last_yield_line_no,
                        )
                    )
                    changes.append(err)
                # The `StopIteration` or `_DefGen_Return` contains the return value from the
                # generator.
                return cast(T, e.value)

            frame = gen.gi_frame

            if isinstance(d, defer.Deferred) and not d.called:
                # This happens if we yield on a deferred that doesn't follow
                # the log context rules without wrapping in a `make_deferred_yieldable`.
                # We raise here as this should never happen.
                if current_context():
                    err = (
                        "%s yielded with context %s rather than sentinel,"
                        " yielded on line %d in %s"
                        % (
                            frame.f_code.co_name,
                            current_context(),
                            frame.f_lineno,
                            frame.f_code.co_filename,
                        )
                    )
                    raise Exception(err)

            # the wrapped function yielded a Deferred: yield it back up to the parent
            # inlineCallbacks().
            try:
                result = yield d
            except Exception:
                # this will fish an earlier Failure out of the stack where possible, and
                # thus is preferable to passing in an exception to the Failure
                # constructor, since it results in less stack-mangling.
                result = Failure()

            if current_context() != expected_context:
                # This happens because the context is lost sometime *after* the
                # previous yield and *after* the current yield. E.g. the
                # deferred we waited on didn't follow the rules, or we forgot to
                # yield on a function between the two yield points.
                #
                # We don't raise here as it's perfectly valid for contexts to
                # change in a function, as long as it sets the correct context
                # on resolving (which is checked separately).
                err = (
                    "%s changed context from %s to %s, happened between lines %d and %d in %s"
                    % (
                        frame.f_code.co_name,
                        expected_context,
                        current_context(),
                        last_yield_line_no,
                        frame.f_lineno,
                        frame.f_code.co_filename,
                    )
                )
                changes.append(err)

            last_yield_line_no = frame.f_lineno

    return check_yield_points_inner