diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
index 64005deae9..e89879d1f4 100644
--- a/synapse/config/oidc2.py
+++ b/synapse/config/oidc2.py
@@ -1,7 +1,6 @@
-from typing import Optional, Tuple, Any
+from typing import TYPE_CHECKING, Any, Optional, Tuple
-from pydantic import BaseModel, StrictStr, validator, StrictBool
-from synapse.config.validators import string_length_between, string_contains_characters
+from pydantic import BaseModel, StrictBool, StrictStr, constr
class OIDCProviderModel(BaseModel):
@@ -16,11 +15,17 @@ class OIDCProviderModel(BaseModel):
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
# TODO: this is optional in the old-style config, defaulting to "oidc".
- idp_id: StrictStr
- _idp_id_length = validator("idp_id")(string_length_between(1, 250))
- _idp_id_characters = validator("idp_id")(
- string_contains_characters("A-Za-z0-9._~-")
- )
+ # Ugly workaround for https://github.com/samuelcolvin/pydantic/issues/156, see also
+ # https://github.com/samuelcolvin/pydantic/issues/156#issuecomment-1130883884
+ if TYPE_CHECKING:
+ idp_id: str
+ else:
+ idp_id: constr(
+ strict=True,
+ min_length=1,
+ max_length=250,
+ regex="^[A-Za-z0-9._~-]+$", # noqa: F722
+ )
# user-facing name for this identity provider.
# TODO: this is optional in the old-style config, defaulting to "OIDC".
@@ -99,6 +104,4 @@ class OIDCProviderModel(BaseModel):
user_mapping_provider_config: Any
# required attributes to require in userinfo to allow login/registration
- attribute_requirements: Tuple[
- Any, ...
- ] = tuple() # TODO SsoAttributeRequirement] = tuple()
+ attribute_requirements: Tuple[Any, ...] = () # TODO SsoAttributeRequirement] = ()
diff --git a/synapse/config/validators.py b/synapse/config/validators.py
deleted file mode 100644
index 2faae0cf4a..0000000000
--- a/synapse/config/validators.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import re
-from typing import Type
-
-from pydantic import BaseModel
-from pydantic.fields import ModelField
-
-
-def string_length_between(lower: int, upper: int):
- def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
- print(f"validate {lower=} {upper=} {value=}")
- if lower <= len(value) <= upper:
- print("ok")
- return value
- print("bad")
- raise ValueError(
- f"{field.name} must be between {lower} and {upper} characters long"
- )
-
- return validator
-
-
-def string_contains_characters(charset: str):
- def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
- pattern = f"^[{charset}]*$"
- if re.match(pattern, value):
- return value
- raise ValueError(
- f"{field.name} must be only contain the characters {charset}"
- )
-
- return validator
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
index 021eba9511..102b35f9cc 100644
--- a/tests/config/test_oidc2.py
+++ b/tests/config/test_oidc2.py
@@ -1,14 +1,14 @@
from copy import deepcopy
-from typing import Any, Mapping, Dict
+from typing import Any, Dict
import yaml
from pydantic import ValidationError
from synapse.config.oidc2 import OIDCProviderModel
-from tests.unittest import TestCase
+from tests.unittest import TestCase
-SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load(
+SAMPLE_CONFIG = yaml.safe_load(
"""
idp_id: apple
idp_name: Apple
@@ -44,20 +44,21 @@ class PydanticOIDCTestCase(TestCase):
def test_idp_id(self) -> None:
"""Demonstrate that Pydantic validates idp_id correctly."""
- # OIDCProviderModel.parse_obj(self.config)
- #
- # # Enforce that idp_id is required.
- # with self.assertRaises(ValidationError):
- # del self.config["idp_id"]
- # OIDCProviderModel.parse_obj(self.config)
- #
- # # Enforce that idp_id is a string.
- # with self.assertRaises(ValidationError):
- # self.config["idp_id"] = 123
- # OIDCProviderModel.parse_obj(self.config)
- # with self.assertRaises(ValidationError):
- # self.config["idp_id"] = None
- # OIDCProviderModel.parse_obj(self.config)
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Enforce that idp_id is required.
+ with self.assertRaises(ValidationError):
+ del self.config["idp_id"]
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Enforce that idp_id is a string.
+ with self.assertRaises(ValidationError) as e:
+ self.config["idp_id"] = 123
+ OIDCProviderModel.parse_obj(self.config)
+ print(e.exception)
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = None
+ OIDCProviderModel.parse_obj(self.config)
# Enforce a length between 1 and 250.
with self.assertRaises(ValidationError):
@@ -71,3 +72,9 @@ class PydanticOIDCTestCase(TestCase):
with self.assertRaises(ValidationError):
self.config["idp_id"] = "$"
OIDCProviderModel.parse_obj(self.config)
+
+ # What happens with a really long string of prohibited characters?
+ with self.assertRaises(ValidationError) as e:
+ self.config["idp_id"] = "$" * 500
+ OIDCProviderModel.parse_obj(self.config)
+ print(e.exception)
diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py
deleted file mode 100644
index 7b57a13d5e..0000000000
--- a/tests/config/test_validators.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from unittest import TestCase
-
-from pydantic import BaseModel, ValidationError, validator, StrictStr
-
-from synapse.config.validators import string_length_between, string_contains_characters
-
-
-class TestValidators(TestCase):
- def test_string_length_between(self) -> None:
- class TestModel(BaseModel):
- x: StrictStr
- _x_length = validator("x")(string_length_between(5, 10))
-
- with self.assertRaises(ValidationError):
- TestModel(x="")
- with self.assertRaises(ValidationError):
- TestModel(x="a" * 4)
-
- # Should not raise:
- TestModel(x="a" * 5)
- TestModel(x="a" * 10)
-
- with self.assertRaises(ValidationError):
- TestModel(x="a" * 11)
- with self.assertRaises(ValidationError):
- TestModel(x="a" * 1000)
-
- def test_string_contains_characters(self) -> None:
- class TestModel(BaseModel):
- x: StrictStr
- _x_characters = validator("x")(string_contains_characters("A-Z0-9"))
-
- # Should not raise
- TestModel(x="")
- TestModel(x="A")
- TestModel(x="B")
- TestModel(x="Z")
- TestModel(x="123456789")
-
- with self.assertRaises(ValidationError):
- TestModel(x="---")
- with self.assertRaises(ValidationError):
- TestModel(x="$")
- with self.assertRaises(ValidationError):
- TestModel(x="A$")
- with self.assertRaises(ValidationError):
- TestModel(x="a")
- with self.assertRaises(ValidationError):
- TestModel(x="\u0000")
- with self.assertRaises(ValidationError):
- TestModel(x="☃")
-
|