summary refs log tree commit diff
path: root/scripts-dev/check_pydantic_models.py
blob: 9f2b7ded5bd51930fc85c936e9b351b04f225d6b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
#! /usr/bin/env python
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A script which enforces that Synapse always uses strict types when defining a Pydantic
model.

Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See

    https://github.com/pydantic/pydantic/issues/1098
    https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode

Until then, this script is a best-effort attempt to stop us from introducing type coercion bugs
(like the infamous stringy power levels fixed in room version 10).
"""
import argparse
import contextlib
import functools
import importlib
import logging
import os
import pkgutil
import sys
import textwrap
import traceback
import unittest.mock
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar

from parameterized import parameterized
from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr
from pydantic.typing import get_args
from typing_extensions import ParamSpec

logger = logging.getLogger(__name__)

# Constrained-type factories which accept a `strict` keyword argument. Inside
# `monkeypatch_pydantic` each of these is replaced by a wrapper (see `make_wrapper`)
# that refuses to run unless `strict=True` was passed.
CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [constr, conbytes, conint, confloat]

# Plain annotations that field_type_unwanted() flags, either on their own or when they
# appear anywhere inside a generic annotation.
TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [str, bytes, int, float, bool]


# Parameter-spec / return type variables, used so that `make_wrapper` returns a
# callable typed exactly like the factory it wraps.
P = ParamSpec("P")
R = TypeVar("R")


class ModelCheckerException(Exception):
    """Dummy exception. Allows us to detect unwanted types during a module import."""


class MissingStrictInConstrainedTypeException(ModelCheckerException):
    """A constrained-type factory (e.g. `constr()`) was called without `strict=True`.

    Attributes:
        factory_name: the `__name__` of the offending factory, e.g. "constr".
    """

    factory_name: str

    def __init__(self, factory_name: str):
        # Chain up so that `e.args` / `str(e)` carry the factory name (previously they
        # were empty) and so the exception survives a pickle round-trip.
        super().__init__(factory_name)
        self.factory_name = factory_name


class FieldHasUnwantedTypeException(ModelCheckerException):
    """A pydantic model field was annotated with a type that pydantic will coerce to.

    Attributes:
        message: a human-readable description of the offending field(s).
    """

    message: str

    def __init__(self, message: str):
        # As above: make `str(e)` the message rather than an empty string.
        super().__init__(message)
        self.message = message


def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]:
    """Return a stand-in for `factory` (`constr` and friends) that demands strict=True.

    The wrapper raises MissingStrictInConstrainedTypeException, naming the factory,
    whenever the `strict` keyword argument is absent or falsy. Otherwise it forwards
    the call unchanged.
    """

    @functools.wraps(factory)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # An absent `strict` is treated exactly like an explicit falsy one.
        strict_requested = kwargs.get("strict", False)
        if not strict_requested:
            raise MissingStrictInConstrainedTypeException(factory.__name__)
        return factory(*args, **kwargs)

    return wrapper


def field_type_unwanted(type_: Any) -> bool:
    """Very rough attempt to detect if a type is unwanted as a Pydantic annotation.

    At present, we exclude types which will coerce, or any generic type involving types
    which will coerce.

    Args:
        type_: an annotation, e.g. `str`, `List[StrictStr]` or `Optional[int]`.

    Returns:
        True if `type_` is one of TYPES_THAT_PYDANTIC_WILL_COERCE_TO, or if any of its
        type arguments (as reported by `get_args`, checked recursively) is.
    """
    # The original call omitted `type_`, so the "%s" placeholder was logged verbatim.
    logger.debug("Is %s unwanted?", type_)
    if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO:
        logger.debug("yes")
        return True
    sub_args = get_args(type_)
    logger.debug("Maybe. Subargs are %s", sub_args)
    rv = any(field_type_unwanted(t) for t in sub_args)
    logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not")
    return rv


class PatchedBaseModel(PydanticBaseModel):
    """A patched version of BaseModel that inspects fields after models are defined.

    We complain loudly if we see an unwanted type. Every offending field of a model is
    reported, one per line, in a single FieldHasUnwantedTypeException.

    Beware: ModelField.type_ is presumably private; this is likely to be very brittle.
    """

    @classmethod
    def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
        # Chain up so that class keyword arguments are handled (or rejected) exactly as
        # they would be without this patch, rather than being silently dropped here.
        super().__init_subclass__(**kwargs)

        problems: List[str] = []
        for field in cls.__fields__.values():
            # Note that field.type_ and field.outer_type are computed based on the
            # annotation type, see pydantic.fields.ModelField._type_analysis
            if field_type_unwanted(field.outer_type_):
                problems.append(
                    f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
                    f"with unwanted type `{field.outer_type_}`"
                )
        if problems:
            raise FieldHasUnwantedTypeException("\n".join(problems))


@contextmanager
def monkeypatch_pydantic() -> Generator[None, None, None]:
    """Swap pydantic's BaseModel and con* factories for our snooping versions.

    For the duration of the `with` block, `BaseModel` is replaced by PatchedBaseModel
    and each factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG by its
    `make_wrapper` counterpart. When a snooping version sees something it doesn't like,
    it raises a ModelCheckerException instance. All patches are undone on exit.
    """
    with contextlib.ExitStack() as stack:
        # Synapse code ought to import these names straight from `pydantic`, but the
        # defining modules `pydantic.main` and `pydantic.types` are patched too, for
        # completeness.
        for basemodel_target in ("pydantic.BaseModel", "pydantic.main.BaseModel"):
            stack.enter_context(
                unittest.mock.patch(basemodel_target, new=PatchedBaseModel)
            )
        for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
            strict_factory: Callable = make_wrapper(factory)
            for module_path in ("pydantic", "pydantic.types"):
                stack.enter_context(
                    unittest.mock.patch(
                        f"{module_path}.{factory.__name__}", new=strict_factory
                    )
                )
        yield


def format_model_checker_exception(e: ModelCheckerException) -> str:
    """Describe, in a human-friendly way, which code caused `e`.

    Raises:
        ValueError: if `e` is not one of the known ModelCheckerException subclasses.
    """
    # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
    #   patches of constr() etc, and instead inspect fields to look for ConstrainedStr
    #   with strict=False? There is some difficulty with the inheritance hierarchy
    #   because StrictStr < ConstrainedStr < str.
    if isinstance(e, FieldHasUnwantedTypeException):
        return e.message

    if not isinstance(e, MissingStrictInConstrainedTypeException):
        raise ValueError(f"Unknown exception {e}") from e

    # The innermost traceback frame is the strict-enforcing wrapper that raised;
    # the frame just outside it is the call site that omitted `strict=True`.
    offending_frame = traceback.extract_tb(e.__traceback__)[-2]
    location = traceback.format_list([offending_frame])[0].lstrip()
    return f"Missing `strict=True` from {e.factory_name}() call \n" + location


def lint() -> int:
    """Import all of Synapse, report any Pydantic type coercions found, and
    return an exit status suitable for sys.exit.

    Problems are printed in sorted order after a count header. The result is
    os.EX_OK when nothing was found and os.EX_DATAERR otherwise.
    """
    problems = do_lint()
    if not problems:
        return os.EX_OK

    print(f"Found {len(problems)} problem(s)")
    for problem in sorted(problems):
        print(problem)
    return os.EX_DATAERR


def do_lint() -> Set[str]:
    """Try to import all of Synapse and see if we spot any Pydantic type coercions.

    Returns:
        A set of formatted problem descriptions (see format_model_checker_exception);
        empty if nothing was found.
    """
    failures: Set[str] = set()
    # Subpackages whose import already failed during the package walk; they are not
    # re-imported below, to avoid reporting the same problem twice.
    failed_during_walk: Set[str] = set()

    def on_walk_error(name: str) -> None:
        # `pkgutil.walk_packages` calls this from inside its `except` clause when
        # importing a subpackage fails. Record bad annotations and keep walking,
        # instead of abandoning the whole lint at the first bad subpackage. Other
        # errors behave as they would with no callback: ImportErrors are ignored and
        # anything else propagates.
        exc = sys.exc_info()[1]
        if isinstance(exc, ModelCheckerException):
            logger.warning("Bad annotation found when importing %s", name)
            failures.add(format_model_checker_exception(exc))
            failed_during_walk.add(name)
        elif not isinstance(exc, ImportError):
            raise

    with monkeypatch_pydantic():
        logger.debug("Importing synapse")
        try:
            # TODO: make "synapse" an argument so we can target this script at
            # a subpackage
            module = importlib.import_module("synapse")
        except ModelCheckerException as e:
            logger.warning("Bad annotation found when importing synapse")
            failures.add(format_model_checker_exception(e))
            return failures

        logger.debug("Fetching subpackages")
        module_infos = list(
            pkgutil.walk_packages(
                module.__path__, f"{module.__name__}.", onerror=on_walk_error
            )
        )

        for module_info in module_infos:
            if module_info.name in failed_during_walk:
                continue
            logger.debug("Importing %s", module_info.name)
            try:
                importlib.import_module(module_info.name)
            except ModelCheckerException as e:
                logger.warning(
                    "Bad annotation found when importing %s", module_info.name
                )
                failures.add(format_model_checker_exception(e))

    return failures


def run_test_snippet(source: str) -> None:
    """Execute `source` in a fresh, empty namespace, as if it were a module's top level.

    The snippet is dedented first, so callers may indent it to match their own code.
    """
    # exec() only behaves like module-level code when the globals and locals it is
    # given are one and the same mapping:
    #
    # > Remember that at the module level, globals and locals are the same dictionary.
    # > If exec gets two separate objects as globals and locals, the code will be
    # > executed as if it were embedded in a class definition.
    namespace: Dict[str, object] = {}
    exec(textwrap.dedent(source), namespace, namespace)


class TestConstrainedTypesPatch(unittest.TestCase):
    """Checks that the patched con* factories reject calls lacking strict=True,
    however the factory is imported or referenced."""

    def test_expression_without_strict_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic import constr
                constr()
                """
            )

    def test_called_as_module_attribute_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                import pydantic
                pydantic.constr()
                """
            )

    def test_wildcard_import_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic import *
                constr()
                """
            )

    def test_alternative_import_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic.types import constr
                constr()
                """
            )

    def test_alternative_import_attribute_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                import pydantic.types
                pydantic.types.constr()
                """
            )

    def test_kwarg_but_no_strict_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic import constr
                constr(min_length=10)
                """
            )

    def test_kwarg_strict_False_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic import constr
                constr(strict=False)
                """
            )

    def test_kwarg_strict_True_doesnt_raise(self) -> None:
        with monkeypatch_pydantic():
            run_test_snippet(
                """
                from pydantic import constr
                constr(strict=True)
                """
            )

    def test_annotation_without_strict_raises(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic import constr
                x: constr()
                """
            )

    def test_field_annotation_without_strict_raises(self) -> None:
        # The class must actually be a pydantic model for this to exercise a model
        # field annotation (the original used a plain `class C:` despite importing
        # BaseModel).
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic import BaseModel, conint
                class C(BaseModel):
                    x: conint()
                """
            )


class TestFieldTypeInspection(unittest.TestCase):
    """Checks that PatchedBaseModel flags fields whose annotations involve coercible
    types, and accepts fields built only from strict types."""

    # Every parameter is written as a 1-tuple. `("bytes")` is just a string, which
    # parameterized happens to wrap, but that form is inconsistent and easy to misread.
    @parameterized.expand(
        [
            ("str",),
            ("bytes",),
            ("int",),
            ("float",),
            ("bool",),
            ("Optional[str]",),
            ("Union[None, str]",),
            ("List[str]",),
            ("List[List[str]]",),
            ("Dict[StrictStr, str]",),
            ("Dict[str, StrictStr]",),
            ("TypedDict('D', x=int)",),
        ]
    )
    def test_field_holding_unwanted_type_raises(self, annotation: str) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                f"""
                from typing import *
                from pydantic import *
                class C(BaseModel):
                    f: {annotation}
                """
            )

    @parameterized.expand(
        [
            ("StrictStr",),
            ("StrictBytes",),
            ("StrictInt",),
            ("StrictFloat",),
            ("StrictBool",),
            ("constr(strict=True, min_length=10)",),
            ("Optional[StrictStr]",),
            ("Union[None, StrictStr]",),
            ("List[StrictStr]",),
            ("List[List[StrictStr]]",),
            ("Dict[StrictStr, StrictStr]",),
            ("TypedDict('D', x=StrictInt)",),
        ]
    )
    def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None:
        with monkeypatch_pydantic():
            run_test_snippet(
                f"""
                from typing import *
                from pydantic import *
                class C(BaseModel):
                    f: {annotation}
                """
            )

    def test_field_holding_str_raises_with_alternative_import(self) -> None:
        with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
            run_test_snippet(
                """
                from pydantic.main import BaseModel
                class C(BaseModel):
                    f: str
                """
            )


# Command-line interface. `mode` selects between linting Synapse (the default) and
# running this script's own unit tests. `-v/--verbose` switches logging to DEBUG.
parser = argparse.ArgumentParser()
parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?")
parser.add_argument("-v", "--verbose", action="store_true")


if __name__ == "__main__":
    args = parser.parse_args(sys.argv[1:])
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s",
    )
    # Only show warnings and above from xmlschema; its logs are not relevant here.
    logging.getLogger("xmlschema").setLevel(logging.WARNING)
    if args.mode == "test":
        # Pass only the program name so unittest doesn't try to parse our own CLI args.
        unittest.main(argv=sys.argv[:1])
    elif args.mode == "lint":
        sys.exit(lint())