summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-19 02:01:02 +0100
committerDavid Robertson <davidr@element.io>2022-05-19 10:29:27 +0100
commit9b9b51be6a77fc96c63dff6d0c20039a1c8b554a (patch)
treefa50d614a3c6e9b752204129195d6d7ed3c5b4dd
parentWIP trying out validators (diff)
downloadsynapse-9b9b51be6a77fc96c63dff6d0c20039a1c8b554a.tar.xz
It seems what I want is `constr`
but this interacts poorly with mypy :(
-rw-r--r--synapse/config/oidc2.py25
-rw-r--r--synapse/config/validators.py31
-rw-r--r--tests/config/test_oidc2.py41
-rw-r--r--tests/config/test_validators.py52
4 files changed, 38 insertions, 111 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
index 64005deae9..e89879d1f4 100644
--- a/synapse/config/oidc2.py
+++ b/synapse/config/oidc2.py
@@ -1,7 +1,6 @@
-from typing import Optional, Tuple, Any
+from typing import TYPE_CHECKING, Any, Optional, Tuple
 
-from pydantic import BaseModel, StrictStr, validator, StrictBool
-from synapse.config.validators import string_length_between, string_contains_characters
+from pydantic import BaseModel, StrictBool, StrictStr, constr
 
 
 class OIDCProviderModel(BaseModel):
@@ -16,11 +15,17 @@ class OIDCProviderModel(BaseModel):
     # a unique identifier for this identity provider. Used in the 'user_external_ids'
     # table, as well as the query/path parameter used in the login protocol.
     # TODO: this is optional in the old-style config, defaulting to "oidc".
-    idp_id: StrictStr
-    _idp_id_length = validator("idp_id")(string_length_between(1, 250))
-    _idp_id_characters = validator("idp_id")(
-        string_contains_characters("A-Za-z0-9._~-")
-    )
+    # Ugly workaround for https://github.com/samuelcolvin/pydantic/issues/156, see also
+    # https://github.com/samuelcolvin/pydantic/issues/156#issuecomment-1130883884
+    if TYPE_CHECKING:
+        idp_id: str
+    else:
+        idp_id: constr(
+            strict=True,
+            min_length=1,
+            max_length=250,
+            regex="^[A-Za-z0-9._~-]+$",  # noqa: F722
+        )
 
     # user-facing name for this identity provider.
     # TODO: this is optional in the old-style config, defaulting to "OIDC".
@@ -99,6 +104,4 @@ class OIDCProviderModel(BaseModel):
     user_mapping_provider_config: Any
 
     # required attributes to require in userinfo to allow login/registration
-    attribute_requirements: Tuple[
-        Any, ...
-    ] = tuple()  # TODO SsoAttributeRequirement] = tuple()
+    attribute_requirements: Tuple[Any, ...] = ()  # TODO SsoAttributeRequirement] = ()
diff --git a/synapse/config/validators.py b/synapse/config/validators.py
deleted file mode 100644
index 2faae0cf4a..0000000000
--- a/synapse/config/validators.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import re
-from typing import Type
-
-from pydantic import BaseModel
-from pydantic.fields import ModelField
-
-
-def string_length_between(lower: int, upper: int):
-    def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
-        print(f"validate {lower=} {upper=} {value=}")
-        if lower <= len(value) <= upper:
-            print("ok")
-            return value
-        print("bad")
-        raise ValueError(
-            f"{field.name} must be between {lower} and {upper} characters long"
-        )
-
-    return validator
-
-
-def string_contains_characters(charset: str):
-    def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
-        pattern = f"^[{charset}]*$"
-        if re.match(pattern, value):
-            return value
-        raise ValueError(
-            f"{field.name} must be only contain the characters {charset}"
-        )
-
-    return validator
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
index 021eba9511..102b35f9cc 100644
--- a/tests/config/test_oidc2.py
+++ b/tests/config/test_oidc2.py
@@ -1,14 +1,14 @@
 from copy import deepcopy
-from typing import Any, Mapping, Dict
+from typing import Any, Dict
 
 import yaml
 from pydantic import ValidationError
 
 from synapse.config.oidc2 import OIDCProviderModel
-from tests.unittest import TestCase
 
+from tests.unittest import TestCase
 
-SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load(
+SAMPLE_CONFIG = yaml.safe_load(
     """
 idp_id: apple
 idp_name: Apple
@@ -44,20 +44,21 @@ class PydanticOIDCTestCase(TestCase):
 
     def test_idp_id(self) -> None:
         """Demonstrate that Pydantic validates idp_id correctly."""
-        # OIDCProviderModel.parse_obj(self.config)
-        #
-        # # Enforce that idp_id is required.
-        # with self.assertRaises(ValidationError):
-        #     del self.config["idp_id"]
-        #     OIDCProviderModel.parse_obj(self.config)
-        #
-        # # Enforce that idp_id is a string.
-        # with self.assertRaises(ValidationError):
-        #     self.config["idp_id"] = 123
-        #     OIDCProviderModel.parse_obj(self.config)
-        # with self.assertRaises(ValidationError):
-        #     self.config["idp_id"] = None
-        #     OIDCProviderModel.parse_obj(self.config)
+        OIDCProviderModel.parse_obj(self.config)
+
+        # Enforce that idp_id is required.
+        with self.assertRaises(ValidationError):
+            del self.config["idp_id"]
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Enforce that idp_id is a string.
+        with self.assertRaises(ValidationError) as e:
+            self.config["idp_id"] = 123
+            OIDCProviderModel.parse_obj(self.config)
+        print(e.exception)
+        with self.assertRaises(ValidationError):
+            self.config["idp_id"] = None
+            OIDCProviderModel.parse_obj(self.config)
 
         # Enforce a length between 1 and 250.
         with self.assertRaises(ValidationError):
@@ -71,3 +72,9 @@ class PydanticOIDCTestCase(TestCase):
         with self.assertRaises(ValidationError):
             self.config["idp_id"] = "$"
             OIDCProviderModel.parse_obj(self.config)
+
+        # What happens with a really long string of prohibited characters?
+        with self.assertRaises(ValidationError) as e:
+            self.config["idp_id"] = "$" * 500
+            OIDCProviderModel.parse_obj(self.config)
+        print(e.exception)
diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py
deleted file mode 100644
index 7b57a13d5e..0000000000
--- a/tests/config/test_validators.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from unittest import TestCase
-
-from pydantic import BaseModel, ValidationError, validator, StrictStr
-
-from synapse.config.validators import string_length_between, string_contains_characters
-
-
-class TestValidators(TestCase):
-    def test_string_length_between(self) -> None:
-        class TestModel(BaseModel):
-            x: StrictStr
-            _x_length = validator("x")(string_length_between(5, 10))
-
-        with self.assertRaises(ValidationError):
-            TestModel(x="")
-        with self.assertRaises(ValidationError):
-            TestModel(x="a" * 4)
-
-        # Should not raise:
-        TestModel(x="a" * 5)
-        TestModel(x="a" * 10)
-
-        with self.assertRaises(ValidationError):
-            TestModel(x="a" * 11)
-        with self.assertRaises(ValidationError):
-            TestModel(x="a" * 1000)
-
-    def test_string_contains_characters(self) -> None:
-        class TestModel(BaseModel):
-            x: StrictStr
-            _x_characters = validator("x")(string_contains_characters("A-Z0-9"))
-
-        # Should not raise
-        TestModel(x="")
-        TestModel(x="A")
-        TestModel(x="B")
-        TestModel(x="Z")
-        TestModel(x="123456789")
-
-        with self.assertRaises(ValidationError):
-            TestModel(x="---")
-        with self.assertRaises(ValidationError):
-            TestModel(x="$")
-        with self.assertRaises(ValidationError):
-            TestModel(x="A$")
-        with self.assertRaises(ValidationError):
-            TestModel(x="a")
-        with self.assertRaises(ValidationError):
-            TestModel(x="\u0000")
-        with self.assertRaises(ValidationError):
-            TestModel(x="☃")
-