summary refs log tree commit diff
path: root/scripts-dev
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev')
-rwxr-xr-xscripts-dev/check_pydantic_models.py425
-rwxr-xr-xscripts-dev/complement.sh1
-rwxr-xr-xscripts-dev/lint.sh1
-rwxr-xr-xscripts-dev/release.py162
4 files changed, 557 insertions, 32 deletions
diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py
new file mode 100755

index 0000000000..d0fb811bdb --- /dev/null +++ b/scripts-dev/check_pydantic_models.py
@@ -0,0 +1,425 @@ +#! /usr/bin/env python +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A script which enforces that Synapse always uses strict types when defining a Pydantic +model. + +Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See + + https://github.com/pydantic/pydantic/issues/1098 + https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode + +until then, this script is a best effort to stop us from introducing type coersion bugs +(like the infamous stringy power levels fixed in room version 10). +""" +import argparse +import contextlib +import functools +import importlib +import logging +import os +import pkgutil +import sys +import textwrap +import traceback +import unittest.mock +from contextlib import contextmanager +from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar + +from parameterized import parameterized +from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr +from pydantic.typing import get_args +from typing_extensions import ParamSpec + +logger = logging.getLogger(__name__) + +CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [ + constr, + conbytes, + conint, + confloat, +] + +TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [ + str, + bytes, + int, + float, + bool, +] + + +P = ParamSpec("P") +R = TypeVar("R") + + +class ModelCheckerException(Exception): + """Dummy exception. Allows us to detect unwanted types during a module import.""" + + +class MissingStrictInConstrainedTypeException(ModelCheckerException): + factory_name: str + + def __init__(self, factory_name: str): + self.factory_name = factory_name + + +class FieldHasUnwantedTypeException(ModelCheckerException): + message: str + + def __init__(self, message: str): + self.message = message + + +def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]: + """We patch `constr` and friends with wrappers that enforce strict=True.""" + + @functools.wraps(factory) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + # type-ignore: should be redundant once we can use https://github.com/python/mypy/pull/12668 + if "strict" not in kwargs: # type: ignore[attr-defined] + raise MissingStrictInConstrainedTypeException(factory.__name__) + if not kwargs["strict"]: # type: ignore[index] + raise MissingStrictInConstrainedTypeException(factory.__name__) + return factory(*args, **kwargs) + + return wrapper + + +def field_type_unwanted(type_: Any) -> bool: + """Very rough attempt to detect if a type is unwanted as a Pydantic annotation. + + At present, we exclude types which will coerce, or any generic type involving types + which will coerce.""" + logger.debug("Is %s unwanted?") + if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO: + logger.debug("yes") + return True + logger.debug("Maybe. Subargs are %s", get_args(type_)) + rv = any(field_type_unwanted(t) for t in get_args(type_)) + logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not") + return rv + + +class PatchedBaseModel(PydanticBaseModel): + """A patched version of BaseModel that inspects fields after models are defined. + + We complain loudly if we see an unwanted type. + + Beware: ModelField.type_ is presumably private; this is likely to be very brittle. + """ + + @classmethod + def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object): + for field in cls.__fields__.values(): + # Note that field.type_ and field.outer_type are computed based on the + # annotation type, see pydantic.fields.ModelField._type_analysis + if field_type_unwanted(field.outer_type_): + # TODO: this only reports the first bad field. Can we find all bad ones + # and report them all? + raise FieldHasUnwantedTypeException( + f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' " + f"with unwanted type `{field.outer_type_}`" + ) + + +@contextmanager +def monkeypatch_pydantic() -> Generator[None, None, None]: + """Patch pydantic with our snooping versions of BaseModel and the con* functions. + + If the snooping functions see something they don't like, they'll raise a + ModelCheckingException instance. + """ + with contextlib.ExitStack() as patches: + # Most Synapse code ought to import the patched objects directly from + # `pydantic`. But we also patch their containing modules `pydantic.main` and + # `pydantic.types` for completeness. + patch_basemodel1 = unittest.mock.patch( + "pydantic.BaseModel", new=PatchedBaseModel + ) + patch_basemodel2 = unittest.mock.patch( + "pydantic.main.BaseModel", new=PatchedBaseModel + ) + patches.enter_context(patch_basemodel1) + patches.enter_context(patch_basemodel2) + for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: + wrapper: Callable = make_wrapper(factory) + patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper) + patch2 = unittest.mock.patch( + f"pydantic.types.{factory.__name__}", new=wrapper + ) + patches.enter_context(patch1) + patches.enter_context(patch2) + yield + + +def format_model_checker_exception(e: ModelCheckerException) -> str: + """Work out which line of code caused e. Format the line in a human-friendly way.""" + # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the + # patches of constr() etc, and instead inspect fields to look for ConstrainedStr + # with strict=False? There is some difficulty with the inheritance hierarchy + # because StrictStr < ConstrainedStr < str. + if isinstance(e, FieldHasUnwantedTypeException): + return e.message + elif isinstance(e, MissingStrictInConstrainedTypeException): + frame_summary = traceback.extract_tb(e.__traceback__)[-2] + return ( + f"Missing `strict=True` from {e.factory_name}() call \n" + + traceback.format_list([frame_summary])[0].lstrip() + ) + else: + raise ValueError(f"Unknown exception {e}") from e + + +def lint() -> int: + """Try to import all of Synapse and see if we spot any Pydantic type coercions. + + Print any problems, then return a status code suitable for sys.exit.""" + failures = do_lint() + if failures: + print(f"Found {len(failures)} problem(s)") + for failure in sorted(failures): + print(failure) + return os.EX_DATAERR if failures else os.EX_OK + + +def do_lint() -> Set[str]: + """Try to import all of Synapse and see if we spot any Pydantic type coercions.""" + failures = set() + + with monkeypatch_pydantic(): + logger.debug("Importing synapse") + try: + # TODO: make "synapse" an argument so we can target this script at + # a subpackage + module = importlib.import_module("synapse") + except ModelCheckerException as e: + logger.warning("Bad annotation found when importing synapse") + failures.add(format_model_checker_exception(e)) + return failures + + try: + logger.debug("Fetching subpackages") + module_infos = list( + pkgutil.walk_packages(module.__path__, f"{module.__name__}.") + ) + except ModelCheckerException as e: + logger.warning("Bad annotation found when looking for modules to import") + failures.add(format_model_checker_exception(e)) + return failures + + for module_info in module_infos: + logger.debug("Importing %s", module_info.name) + try: + importlib.import_module(module_info.name) + except ModelCheckerException as e: + logger.warning( + f"Bad annotation found when importing {module_info.name}" + ) + failures.add(format_model_checker_exception(e)) + + return failures + + +def run_test_snippet(source: str) -> None: + """Exec a snippet of source code in an isolated environment.""" + # To emulate `source` being called at the top level of the module, + # the globals and locals we provide apparently have to be the same mapping. + # + # > Remember that at the module level, globals and locals are the same dictionary. + # > If exec gets two separate objects as globals and locals, the code will be + # > executed as if it were embedded in a class definition. + globals_: Dict[str, object] + locals_: Dict[str, object] + globals_ = locals_ = {} + exec(textwrap.dedent(source), globals_, locals_) + + +class TestConstrainedTypesPatch(unittest.TestCase): + def test_expression_without_strict_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic import constr + constr() + """ + ) + + def test_called_as_module_attribute_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + import pydantic + pydantic.constr() + """ + ) + + def test_wildcard_import_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic import * + constr() + """ + ) + + def test_alternative_import_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic.types import constr + constr() + """ + ) + + def test_alternative_import_attribute_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + import pydantic.types + pydantic.types.constr() + """ + ) + + def test_kwarg_but_no_strict_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic import constr + constr(min_length=10) + """ + ) + + def test_kwarg_strict_False_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic import constr + constr(strict=False) + """ + ) + + def test_kwarg_strict_True_doesnt_raise(self) -> None: + with monkeypatch_pydantic(): + run_test_snippet( + """ + from pydantic import constr + constr(strict=True) + """ + ) + + def test_annotation_without_strict_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic import constr + x: constr() + """ + ) + + def test_field_annotation_without_strict_raises(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic import BaseModel, conint + class C: + x: conint() + """ + ) + + +class TestFieldTypeInspection(unittest.TestCase): + @parameterized.expand( + [ + ("str",), + ("bytes"), + ("int",), + ("float",), + ("bool"), + ("Optional[str]",), + ("Union[None, str]",), + ("List[str]",), + ("List[List[str]]",), + ("Dict[StrictStr, str]",), + ("Dict[str, StrictStr]",), + ("TypedDict('D', x=int)",), + ] + ) + def test_field_holding_unwanted_type_raises(self, annotation: str) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + f""" + from typing import * + from pydantic import * + class C(BaseModel): + f: {annotation} + """ + ) + + @parameterized.expand( + [ + ("StrictStr",), + ("StrictBytes"), + ("StrictInt",), + ("StrictFloat",), + ("StrictBool"), + ("constr(strict=True, min_length=10)",), + ("Optional[StrictStr]",), + ("Union[None, StrictStr]",), + ("List[StrictStr]",), + ("List[List[StrictStr]]",), + ("Dict[StrictStr, StrictStr]",), + ("TypedDict('D', x=StrictInt)",), + ] + ) + def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None: + with monkeypatch_pydantic(): + run_test_snippet( + f""" + from typing import * + from pydantic import * + class C(BaseModel): + f: {annotation} + """ + ) + + def test_field_holding_str_raises_with_alternative_import(self) -> None: + with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException): + run_test_snippet( + """ + from pydantic.main import BaseModel + class C(BaseModel): + f: str + """ + ) + + +parser = argparse.ArgumentParser() +parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?") +parser.add_argument("-v", "--verbose", action="store_true") + + +if __name__ == "__main__": + args = parser.parse_args(sys.argv[1:]) + logging.basicConfig( + format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s", + level=logging.DEBUG if args.verbose else logging.INFO, + ) + # suppress logs we don't care about + logging.getLogger("xmlschema").setLevel(logging.WARNING) + if args.mode == "lint": + sys.exit(lint()) + elif args.mode == "test": + unittest.main(argv=sys.argv[:1]) diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 6381f7092e..eab23f18f1 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh
@@ -101,6 +101,7 @@ if [ -z "$skip_docker_build" ]; then echo_if_github "::group::Build Docker image: matrixdotorg/synapse" docker build -t matrixdotorg/synapse \ --build-arg TEST_ONLY_SKIP_DEP_HASH_VERIFICATION \ + --build-arg TEST_ONLY_IGNORE_POETRY_LOCKFILE \ -f "docker/Dockerfile" . echo_if_github "::endgroup::" diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 377348b107..bf900645b1 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh
@@ -106,4 +106,5 @@ isort "${files[@]}" python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh flake8 "${files[@]}" +./scripts-dev/check_pydantic_models.py lint mypy diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 0031ba3e4b..46220c4dd3 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py
@@ -32,6 +32,7 @@ import click import commonmark import git from click.exceptions import ClickException +from git import GitCommandError, Repo from github import Github from packaging import version @@ -55,9 +56,12 @@ def run_until_successful( def cli() -> None: """An interactive script to walk through the parts of creating a release. - Requires the dev dependencies be installed, which can be done via: + Requirements: + - The dev dependencies be installed, which can be done via: - pip install -e .[dev] + pip install -e .[dev] + + - A checkout of the sytest repository at ../sytest Then to use: @@ -75,6 +79,8 @@ def cli() -> None: # Optional: generate some nice links for the announcement + ./scripts-dev/release.py merge-back + ./scripts-dev/release.py announce If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the @@ -89,10 +95,12 @@ def prepare() -> None: """ # Make sure we're in a git repo. - repo = get_repo_and_check_clean_checkout() + synapse_repo = get_repo_and_check_clean_checkout() + sytest_repo = get_repo_and_check_clean_checkout("../sytest", "sytest") - click.secho("Updating git repo...") - repo.remote().fetch() + click.secho("Updating Synapse and Sytest git repos...") + synapse_repo.remote().fetch() + sytest_repo.remote().fetch() # Get the current version and AST from root Synapse module. current_version = get_package_version() @@ -166,12 +174,12 @@ def prepare() -> None: assert not parsed_new_version.is_postrelease release_branch_name = get_release_branch_name(parsed_new_version) - release_branch = find_ref(repo, release_branch_name) + release_branch = find_ref(synapse_repo, release_branch_name) if release_branch: if release_branch.is_remote(): # If the release branch only exists on the remote we check it out # locally. - repo.git.checkout(release_branch_name) + synapse_repo.git.checkout(release_branch_name) else: # If a branch doesn't exist we create one. We ask which one branch it # should be based off, defaulting to sensible values depending on the @@ -187,25 +195,34 @@ def prepare() -> None: "Which branch should the release be based on?", default=default ) - base_branch = find_ref(repo, branch_name) - if not base_branch: - print(f"Could not find base branch {branch_name}!") - click.get_current_context().abort() + for repo_name, repo in {"synapse": synapse_repo, "sytest": sytest_repo}.items(): + base_branch = find_ref(repo, branch_name) + if not base_branch: + print(f"Could not find base branch {branch_name} for {repo_name}!") + click.get_current_context().abort() + + # Check out the base branch and ensure it's up to date + repo.head.set_reference( + base_branch, f"check out the base branch for {repo_name}" + ) + repo.head.reset(index=True, working_tree=True) + if not base_branch.is_remote(): + update_branch(repo) - # Check out the base branch and ensure it's up to date - repo.head.set_reference(base_branch, "check out the base branch") - repo.head.reset(index=True, working_tree=True) - if not base_branch.is_remote(): - update_branch(repo) + # Create the new release branch + # Type ignore will no longer be needed after GitPython 3.1.28. + # See https://github.com/gitpython-developers/GitPython/pull/1419 + repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type] - # Create the new release branch - # Type ignore will no longer be needed after GitPython 3.1.28. - # See https://github.com/gitpython-developers/GitPython/pull/1419 - repo.create_head(release_branch_name, commit=base_branch) # type: ignore[arg-type] + # Special-case SyTest: we don't actually prepare any files so we may + # as well push it now (and only when we create a release branch; + # not on subsequent RCs or full releases). + if click.confirm("Push new SyTest branch?", default=True): + sytest_repo.git.push("-u", sytest_repo.remote().name, release_branch_name) # Switch to the release branch and ensure it's up to date. - repo.git.checkout(release_branch_name) - update_branch(repo) + synapse_repo.git.checkout(release_branch_name) + update_branch(synapse_repo) # Update the version specified in pyproject.toml. subprocess.check_output(["poetry", "version", new_version]) @@ -230,15 +247,15 @@ def prepare() -> None: run_until_successful('dch -M -r -D stable ""', shell=True) # Show the user the changes and ask if they want to edit the change log. - repo.git.add("-u") + synapse_repo.git.add("-u") subprocess.run("git diff --cached", shell=True) if click.confirm("Edit changelog?", default=False): click.edit(filename="CHANGES.md") # Commit the changes. - repo.git.add("-u") - repo.git.commit("-m", new_version) + synapse_repo.git.add("-u") + synapse_repo.git.commit("-m", new_version) # We give the option to bail here in case the user wants to make sure things # are OK before pushing. @@ -246,17 +263,21 @@ def prepare() -> None: print("") print("Run when ready to push:") print("") - print(f"\tgit push -u {repo.remote().name} {repo.active_branch.name}") + print( + f"\tgit push -u {synapse_repo.remote().name} {synapse_repo.active_branch.name}" + ) print("") sys.exit(0) # Otherwise, push and open the changelog in the browser. - repo.git.push("-u", repo.remote().name, repo.active_branch.name) + synapse_repo.git.push( + "-u", synapse_repo.remote().name, synapse_repo.active_branch.name + ) print("Opening the changelog in your browser...") print("Please ask others to give it a check.") click.launch( - f"https://github.com/matrix-org/synapse/blob/{repo.active_branch.name}/CHANGES.md" + f"https://github.com/matrix-org/synapse/blob/{synapse_repo.active_branch.name}/CHANGES.md" ) @@ -423,6 +444,79 @@ def upload() -> None: ) +def _merge_into(repo: Repo, source: str, target: str) -> None: + """ + Merges branch `source` into branch `target`. + Pulls both before merging and pushes the result. + """ + + # Update our branches and switch to the target branch + for branch in [source, target]: + click.echo(f"Switching to {branch} and pulling...") + repo.heads[branch].checkout() + # Pull so we're up to date + repo.remote().pull() + + assert repo.active_branch.name == target + + try: + # TODO This seemed easier than using GitPython directly + click.echo(f"Merging {source}...") + repo.git.merge(source) + except GitCommandError as exc: + # If a merge conflict occurs, give some context and try to + # make it easy to abort if necessary. + click.echo(exc) + if not click.confirm( + f"Likely merge conflict whilst merging ({source} → {target}). " + f"Have you resolved it?" + ): + repo.git.merge("--abort") + return + + # Push result. + click.echo("Pushing...") + repo.remote().push() + + +@cli.command() +def merge_back() -> None: + """Merge the release branch back into the appropriate branches. + All branches will be automatically pulled from the remote and the results + will be pushed to the remote.""" + + synapse_repo = get_repo_and_check_clean_checkout() + branch_name = synapse_repo.active_branch.name + + if not branch_name.startswith("release-v"): + raise RuntimeError("Not on a release branch. This does not seem sensible.") + + # Pull so we're up to date + synapse_repo.remote().pull() + + current_version = get_package_version() + + if current_version.is_prerelease: + # Release candidate + if click.confirm(f"Merge {branch_name} → develop?", default=True): + _merge_into(synapse_repo, branch_name, "develop") + else: + # Full release + sytest_repo = get_repo_and_check_clean_checkout("../sytest", "sytest") + + if click.confirm(f"Merge {branch_name} → master?", default=True): + _merge_into(synapse_repo, branch_name, "master") + + if click.confirm("Merge master → develop?", default=True): + _merge_into(synapse_repo, "master", "develop") + + if click.confirm(f"On SyTest, merge {branch_name} → master?", default=True): + _merge_into(sytest_repo, branch_name, "master") + + if click.confirm("On SyTest, merge master → develop?", default=True): + _merge_into(sytest_repo, "master", "develop") + + @cli.command() def announce() -> None: """Generate markdown to announce the release.""" @@ -469,14 +563,18 @@ def get_release_branch_name(version_number: version.Version) -> str: return f"release-v{version_number.major}.{version_number.minor}" -def get_repo_and_check_clean_checkout() -> git.Repo: +def get_repo_and_check_clean_checkout( + path: str = ".", name: str = "synapse" +) -> git.Repo: """Get the project repo and check it's not got any uncommitted changes.""" try: - repo = git.Repo() + repo = git.Repo(path=path) except git.InvalidGitRepositoryError: - raise click.ClickException("Not in Synapse repo.") + raise click.ClickException( + f"{path} is not a git repository (expecting a {name} repository)." + ) if repo.is_dirty(): - raise click.ClickException("Uncommitted changes exist.") + raise click.ClickException(f"Uncommitted changes exist in {path}.") return repo