diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py
index 9f2b7ded5b..d1cfc9a85c 100755
--- a/scripts-dev/check_pydantic_models.py
+++ b/scripts-dev/check_pydantic_models.py
@@ -36,11 +36,41 @@ import textwrap
import traceback
import unittest.mock
from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Set,
+ Type,
+ TypeVar,
+)
from parameterized import parameterized
-from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr
-from pydantic.typing import get_args
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import (
+ BaseModel as PydanticBaseModel,
+ conbytes,
+ confloat,
+ conint,
+ constr,
+ )
+ from pydantic.v1.typing import get_args
+else:
+ from pydantic import (
+ BaseModel as PydanticBaseModel,
+ conbytes,
+ confloat,
+ conint,
+ constr,
+ )
+ from pydantic.typing import get_args
+
from typing_extensions import ParamSpec
logger = logging.getLogger(__name__)
@@ -251,7 +281,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic import constr
+ try:
+ from pydantic.v1 import constr
+ except ImportError:
+ from pydantic import constr
constr()
"""
)
@@ -269,7 +302,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic import *
+ try:
+ from pydantic.v1 import *
+ except ImportError:
+ from pydantic import *
constr()
"""
)
@@ -278,7 +314,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic.types import constr
+ try:
+ from pydantic.v1.types import constr
+ except ImportError:
+ from pydantic.types import constr
constr()
"""
)
@@ -287,8 +326,11 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- import pydantic.types
- pydantic.types.constr()
+ try:
+ from pydantic.v1 import types as pydantic_types
+ except ImportError:
+ from pydantic import types as pydantic_types
+ pydantic_types.constr()
"""
)
@@ -296,7 +338,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic import constr
+ try:
+ from pydantic.v1 import constr
+ except ImportError:
+ from pydantic import constr
constr(min_length=10)
"""
)
@@ -305,7 +350,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic import constr
+ try:
+ from pydantic.v1 import constr
+ except ImportError:
+ from pydantic import constr
constr(strict=False)
"""
)
@@ -314,7 +362,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic():
run_test_snippet(
"""
- from pydantic import constr
+ try:
+ from pydantic.v1 import constr
+ except ImportError:
+ from pydantic import constr
constr(strict=True)
"""
)
@@ -323,7 +374,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic import constr
+ try:
+ from pydantic.v1 import constr
+ except ImportError:
+ from pydantic import constr
x: constr()
"""
)
@@ -332,7 +386,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic import BaseModel, conint
+ try:
+ from pydantic.v1 import BaseModel, conint
+ except ImportError:
+ from pydantic import BaseModel, conint
class C:
x: conint()
"""
@@ -361,7 +418,10 @@ class TestFieldTypeInspection(unittest.TestCase):
run_test_snippet(
f"""
from typing import *
- from pydantic import *
+ try:
+ from pydantic.v1 import *
+ except ImportError:
+ from pydantic import *
class C(BaseModel):
f: {annotation}
"""
@@ -388,7 +448,10 @@ class TestFieldTypeInspection(unittest.TestCase):
run_test_snippet(
f"""
from typing import *
- from pydantic import *
+ try:
+ from pydantic.v1 import *
+ except ImportError:
+ from pydantic import *
class C(BaseModel):
f: {annotation}
"""
@@ -398,7 +461,10 @@ class TestFieldTypeInspection(unittest.TestCase):
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
- from pydantic.main import BaseModel
+ try:
+ from pydantic.v1.main import BaseModel
+ except ImportError:
+ from pydantic.main import BaseModel
class C(BaseModel):
f: str
"""
|