diff options
author | David Robertson <davidr@element.io> | 2022-05-19 02:01:02 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-05-19 10:29:27 +0100 |
commit | 9b9b51be6a77fc96c63dff6d0c20039a1c8b554a (patch) | |
tree | fa50d614a3c6e9b752204129195d6d7ed3c5b4dd | |
parent | WIP trying out validators (diff) | |
download | synapse-9b9b51be6a77fc96c63dff6d0c20039a1c8b554a.tar.xz |
It seems what I want is `constr`
but this interacts poorly with mypy :(
-rw-r--r-- | synapse/config/oidc2.py | 25 | ||||
-rw-r--r-- | synapse/config/validators.py | 31 | ||||
-rw-r--r-- | tests/config/test_oidc2.py | 41 | ||||
-rw-r--r-- | tests/config/test_validators.py | 52 |
4 files changed, 38 insertions, 111 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py index 64005deae9..e89879d1f4 100644 --- a/synapse/config/oidc2.py +++ b/synapse/config/oidc2.py @@ -1,7 +1,6 @@ -from typing import Optional, Tuple, Any +from typing import TYPE_CHECKING, Any, Optional, Tuple -from pydantic import BaseModel, StrictStr, validator, StrictBool -from synapse.config.validators import string_length_between, string_contains_characters +from pydantic import BaseModel, StrictBool, StrictStr, constr class OIDCProviderModel(BaseModel): @@ -16,11 +15,17 @@ class OIDCProviderModel(BaseModel): # a unique identifier for this identity provider. Used in the 'user_external_ids' # table, as well as the query/path parameter used in the login protocol. # TODO: this is optional in the old-style config, defaulting to "oidc". - idp_id: StrictStr - _idp_id_length = validator("idp_id")(string_length_between(1, 250)) - _idp_id_characters = validator("idp_id")( - string_contains_characters("A-Za-z0-9._~-") - ) + # Ugly workaround for https://github.com/samuelcolvin/pydantic/issues/156, see also + # https://github.com/samuelcolvin/pydantic/issues/156#issuecomment-1130883884 + if TYPE_CHECKING: + idp_id: str + else: + idp_id: constr( + strict=True, + min_length=1, + max_length=250, + regex="^[A-Za-z0-9._~-]+$", # noqa: F722 + ) # user-facing name for this identity provider. # TODO: this is optional in the old-style config, defaulting to "OIDC". @@ -99,6 +104,4 @@ class OIDCProviderModel(BaseModel): user_mapping_provider_config: Any # required attributes to require in userinfo to allow login/registration - attribute_requirements: Tuple[ - Any, ... - ] = tuple() # TODO SsoAttributeRequirement] = tuple() + attribute_requirements: Tuple[Any, ...] = () # TODO SsoAttributeRequirement] = () diff --git a/synapse/config/validators.py b/synapse/config/validators.py deleted file mode 100644 index 2faae0cf4a..0000000000 --- a/synapse/config/validators.py +++ /dev/null @@ -1,31 +0,0 @@ -import re -from typing import Type - -from pydantic import BaseModel -from pydantic.fields import ModelField - - -def string_length_between(lower: int, upper: int): - def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str: - print(f"validate {lower=} {upper=} {value=}") - if lower <= len(value) <= upper: - print("ok") - return value - print("bad") - raise ValueError( - f"{field.name} must be between {lower} and {upper} characters long" - ) - - return validator - - -def string_contains_characters(charset: str): - def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str: - pattern = f"^[{charset}]*$" - if re.match(pattern, value): - return value - raise ValueError( - f"{field.name} must be only contain the characters {charset}" - ) - - return validator diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py index 021eba9511..102b35f9cc 100644 --- a/tests/config/test_oidc2.py +++ b/tests/config/test_oidc2.py @@ -1,14 +1,14 @@ from copy import deepcopy -from typing import Any, Mapping, Dict +from typing import Any, Dict import yaml from pydantic import ValidationError from synapse.config.oidc2 import OIDCProviderModel -from tests.unittest import TestCase +from tests.unittest import TestCase -SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load( +SAMPLE_CONFIG = yaml.safe_load( """ idp_id: apple idp_name: Apple @@ -44,20 +44,21 @@ class PydanticOIDCTestCase(TestCase): def test_idp_id(self) -> None: """Demonstrate that Pydantic validates idp_id correctly.""" - # OIDCProviderModel.parse_obj(self.config) - # - # # Enforce that idp_id is required. - # with self.assertRaises(ValidationError): - # del self.config["idp_id"] - # OIDCProviderModel.parse_obj(self.config) - # - # # Enforce that idp_id is a string. - # with self.assertRaises(ValidationError): - # self.config["idp_id"] = 123 - # OIDCProviderModel.parse_obj(self.config) - # with self.assertRaises(ValidationError): - # self.config["idp_id"] = None - # OIDCProviderModel.parse_obj(self.config) + OIDCProviderModel.parse_obj(self.config) + + # Enforce that idp_id is required. + with self.assertRaises(ValidationError): + del self.config["idp_id"] + OIDCProviderModel.parse_obj(self.config) + + # Enforce that idp_id is a string. + with self.assertRaises(ValidationError) as e: + self.config["idp_id"] = 123 + OIDCProviderModel.parse_obj(self.config) + print(e.exception) + with self.assertRaises(ValidationError): + self.config["idp_id"] = None + OIDCProviderModel.parse_obj(self.config) # Enforce a length between 1 and 250. with self.assertRaises(ValidationError): @@ -71,3 +72,9 @@ class PydanticOIDCTestCase(TestCase): with self.assertRaises(ValidationError): self.config["idp_id"] = "$" OIDCProviderModel.parse_obj(self.config) + + # What happens with a really long string of prohibited characters? + with self.assertRaises(ValidationError) as e: + self.config["idp_id"] = "$" * 500 + OIDCProviderModel.parse_obj(self.config) + print(e.exception) diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py deleted file mode 100644 index 7b57a13d5e..0000000000 --- a/tests/config/test_validators.py +++ /dev/null @@ -1,52 +0,0 @@ -from unittest import TestCase - -from pydantic import BaseModel, ValidationError, validator, StrictStr - -from synapse.config.validators import string_length_between, string_contains_characters - - -class TestValidators(TestCase): - def test_string_length_between(self) -> None: - class TestModel(BaseModel): - x: StrictStr - _x_length = validator("x")(string_length_between(5, 10)) - - with self.assertRaises(ValidationError): - TestModel(x="") - with self.assertRaises(ValidationError): - TestModel(x="a" * 4) - - # Should not raise: - TestModel(x="a" * 5) - TestModel(x="a" * 10) - - with self.assertRaises(ValidationError): - TestModel(x="a" * 11) - with self.assertRaises(ValidationError): - TestModel(x="a" * 1000) - - def test_string_contains_characters(self) -> None: - class TestModel(BaseModel): - x: StrictStr - _x_characters = validator("x")(string_contains_characters("A-Z0-9")) - - # Should not raise - TestModel(x="") - TestModel(x="A") - TestModel(x="B") - TestModel(x="Z") - TestModel(x="123456789") - - with self.assertRaises(ValidationError): - TestModel(x="---") - with self.assertRaises(ValidationError): - TestModel(x="$") - with self.assertRaises(ValidationError): - TestModel(x="A$") - with self.assertRaises(ValidationError): - TestModel(x="a") - with self.assertRaises(ValidationError): - TestModel(x="\u0000") - with self.assertRaises(ValidationError): - TestModel(x="☃") - |