summary refs log tree commit diff
path: root/scripts-dev/check_pydantic_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts-dev/check_pydantic_models.py')
-rwxr-xr-xscripts-dev/check_pydantic_models.py98
1 files changed, 82 insertions, 16 deletions
diff --git a/scripts-dev/check_pydantic_models.py b/scripts-dev/check_pydantic_models.py
index 9f2b7ded5b..d1cfc9a85c 100755
--- a/scripts-dev/check_pydantic_models.py
+++ b/scripts-dev/check_pydantic_models.py
@@ -36,11 +36,41 @@ import textwrap
 import traceback
 import unittest.mock
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Set,
+    Type,
+    TypeVar,
+)
 
 from parameterized import parameterized
-from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr
-from pydantic.typing import get_args
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+    from pydantic.v1 import (
+        BaseModel as PydanticBaseModel,
+        conbytes,
+        confloat,
+        conint,
+        constr,
+    )
+    from pydantic.v1.typing import get_args
+else:
+    from pydantic import (
+        BaseModel as PydanticBaseModel,
+        conbytes,
+        confloat,
+        conint,
+        constr,
+    )
+    from pydantic.typing import get_args
+
 from typing_extensions import ParamSpec
 
 logger = logging.getLogger(__name__)
@@ -251,7 +281,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic import constr
+                try:
+                    from pydantic.v1 import constr
+                except ImportError:
+                    from pydantic import constr
                 constr()
                 """
             )
@@ -269,7 +302,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic import *
+                try:
+                    from pydantic.v1 import *
+                except ImportError:
+                    from pydantic import *
                 constr()
                 """
             )
@@ -278,7 +314,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic.types import constr
+                try:
+                    from pydantic.v1.types import constr
+                except ImportError:
+                    from pydantic.types import constr
                 constr()
                 """
             )
@@ -287,8 +326,11 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                import pydantic.types
-                pydantic.types.constr()
+                try:
+                    from pydantic.v1 import types as pydantic_types
+                except ImportError:
+                    from pydantic import types as pydantic_types
+                pydantic_types.constr()
                 """
             )
 
@@ -296,7 +338,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic import constr
+                try:
+                    from pydantic.v1 import constr
+                except ImportError:
+                    from pydantic import constr
                 constr(min_length=10)
                 """
             )
@@ -305,7 +350,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic import constr
+                try:
+                    from pydantic.v1 import constr
+                except ImportError:
+                    from pydantic import constr
                 constr(strict=False)
                 """
             )
@@ -314,7 +362,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic():
             run_test_snippet(
                 """
-                from pydantic import constr
+                try:
+                    from pydantic.v1 import constr
+                except ImportError:
+                    from pydantic import constr
                 constr(strict=True)
                 """
             )
@@ -323,7 +374,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic import constr
+                try:
+                    from pydantic.v1 import constr
+                except ImportError:
+                    from pydantic import constr
                 x: constr()
                 """
             )
@@ -332,7 +386,10 @@ class TestConstrainedTypesPatch(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic import BaseModel, conint
+                try:
+                    from pydantic.v1 import BaseModel, conint
+                except ImportError:
+                    from pydantic import BaseModel, conint
                 class C:
                     x: conint()
                 """
@@ -361,7 +418,10 @@ class TestFieldTypeInspection(unittest.TestCase):
             run_test_snippet(
                 f"""
                 from typing import *
-                from pydantic import *
+                try:
+                    from pydantic.v1 import *
+                except ImportError:
+                    from pydantic import *
                 class C(BaseModel):
                     f: {annotation}
                 """
@@ -388,7 +448,10 @@ class TestFieldTypeInspection(unittest.TestCase):
             run_test_snippet(
                 f"""
                 from typing import *
-                from pydantic import *
+                try:
+                    from pydantic.v1 import *
+                except ImportError:
+                    from pydantic import *
                 class C(BaseModel):
                     f: {annotation}
                 """
@@ -398,7 +461,10 @@ class TestFieldTypeInspection(unittest.TestCase):
         with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
             run_test_snippet(
                 """
-                from pydantic.main import BaseModel
+                try:
+                    from pydantic.v1.main import BaseModel
+                except ImportError:
+                    from pydantic.main import BaseModel
                 class C(BaseModel):
                     f: str
                 """