diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py
index 26d667aba0..5eb1f0a9df 100755
--- a/scripts-dev/check_pydantic_models.py
+++ b/scripts-dev/check_pydantic_models.py
@@ -45,7 +45,6 @@ import traceback
import unittest.mock
from contextlib import contextmanager
from typing import (
- TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -57,30 +56,17 @@ from typing import (
)
from parameterized import parameterized
-
-from synapse._pydantic_compat import HAS_PYDANTIC_V2
-
-if TYPE_CHECKING or HAS_PYDANTIC_V2:
- from pydantic.v1 import (
- BaseModel as PydanticBaseModel,
- conbytes,
- confloat,
- conint,
- constr,
- )
- from pydantic.v1.typing import get_args
-else:
- from pydantic import (
- BaseModel as PydanticBaseModel,
- conbytes,
- confloat,
- conint,
- constr,
- )
- from pydantic.typing import get_args
-
from typing_extensions import ParamSpec
+from synapse._pydantic_compat import (
+ BaseModel as PydanticBaseModel,
+ conbytes,
+ confloat,
+ conint,
+ constr,
+ get_args,
+)
+
logger = logging.getLogger(__name__)
CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
@@ -183,22 +169,16 @@ def monkeypatch_pydantic() -> Generator[None, None, None]:
# Most Synapse code ought to import the patched objects directly from
# `pydantic`. But we also patch their containing modules `pydantic.main` and
# `pydantic.types` for completeness.
- patch_basemodel1 = unittest.mock.patch(
- "pydantic.BaseModel", new=PatchedBaseModel
- )
- patch_basemodel2 = unittest.mock.patch(
- "pydantic.main.BaseModel", new=PatchedBaseModel
+ patch_basemodel = unittest.mock.patch(
+ "synapse._pydantic_compat.BaseModel", new=PatchedBaseModel
)
- patches.enter_context(patch_basemodel1)
- patches.enter_context(patch_basemodel2)
+ patches.enter_context(patch_basemodel)
for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
wrapper: Callable = make_wrapper(factory)
- patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper)
- patch2 = unittest.mock.patch(
- f"pydantic.types.{factory.__name__}", new=wrapper
+ patch = unittest.mock.patch(
+ f"synapse._pydantic_compat.{factory.__name__}", new=wrapper
)
- patches.enter_context(patch1)
- patches.enter_context(patch2)
+ patches.enter_context(patch)
yield
|