2 files changed, 125 insertions, 0 deletions
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
new file mode 100644
index 0000000000..021eba9511
--- /dev/null
+++ b/tests/config/test_oidc2.py
@@ -0,0 +1,73 @@
+from copy import deepcopy
+from typing import Any, Mapping, Dict
+
+import yaml
+from pydantic import ValidationError
+
+from synapse.config.oidc2 import OIDCProviderModel
+from tests.unittest import TestCase
+
+
+SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load(
+ """
+idp_id: apple
+idp_name: Apple
+idp_icon: "mxc://matrix.org/blahblahblah"
+idp_brand: "apple"
+issuer: "https://appleid.apple.com"
+client_id: "org.matrix.synapse.sso.service"
+client_secret_jwt_key:
+ key: DUMMY_PRIVATE_KEY
+ jwt_header:
+ alg: ES256
+ kid: potato123
+ jwt_payload:
+ iss: issuer456
+client_auth_method: "client_secret_post"
+scopes: ["name", "email", "openid"]
+authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+user_mapping_provider:
+ config:
+ email_template: "{{ user.email }}"
+ localpart_template: "{{ user.email|localpart_from_email }}"
+ confirm_localpart: true
+"""
+)
+
+
+class PydanticOIDCTestCase(TestCase):
+ # Each test gets a dummy config it can change as it sees fit
+ config: Dict[str, Any]
+
+ def setUp(self) -> None:
+ self.config = deepcopy(SAMPLE_CONFIG)
+
+ def test_idp_id(self) -> None:
+ """Demonstrate that Pydantic validates idp_id correctly."""
+ # OIDCProviderModel.parse_obj(self.config)
+ #
+ # # Enforce that idp_id is required.
+ # with self.assertRaises(ValidationError):
+ # del self.config["idp_id"]
+ # OIDCProviderModel.parse_obj(self.config)
+ #
+ # # Enforce that idp_id is a string.
+ # with self.assertRaises(ValidationError):
+ # self.config["idp_id"] = 123
+ # OIDCProviderModel.parse_obj(self.config)
+ # with self.assertRaises(ValidationError):
+ # self.config["idp_id"] = None
+ # OIDCProviderModel.parse_obj(self.config)
+
+ # Enforce a length between 1 and 250.
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = ""
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = "a" * 251
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Enforce the character set
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = "$"
+ OIDCProviderModel.parse_obj(self.config)
diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py
new file mode 100644
index 0000000000..7b57a13d5e
--- /dev/null
+++ b/tests/config/test_validators.py
@@ -0,0 +1,52 @@
+from unittest import TestCase
+
+from pydantic import BaseModel, ValidationError, validator, StrictStr
+
+from synapse.config.validators import string_length_between, string_contains_characters
+
+
+class TestValidators(TestCase):
+ def test_string_length_between(self) -> None:
+ class TestModel(BaseModel):
+ x: StrictStr
+ _x_length = validator("x")(string_length_between(5, 10))
+
+ with self.assertRaises(ValidationError):
+ TestModel(x="")
+ with self.assertRaises(ValidationError):
+ TestModel(x="a" * 4)
+
+ # Should not raise:
+ TestModel(x="a" * 5)
+ TestModel(x="a" * 10)
+
+ with self.assertRaises(ValidationError):
+ TestModel(x="a" * 11)
+ with self.assertRaises(ValidationError):
+ TestModel(x="a" * 1000)
+
+ def test_string_contains_characters(self) -> None:
+ class TestModel(BaseModel):
+ x: StrictStr
+ _x_characters = validator("x")(string_contains_characters("A-Z0-9"))
+
+ # Should not raise
+ TestModel(x="")
+ TestModel(x="A")
+ TestModel(x="B")
+ TestModel(x="Z")
+ TestModel(x="123456789")
+
+ with self.assertRaises(ValidationError):
+ TestModel(x="---")
+ with self.assertRaises(ValidationError):
+ TestModel(x="$")
+ with self.assertRaises(ValidationError):
+ TestModel(x="A$")
+ with self.assertRaises(ValidationError):
+ TestModel(x="a")
+ with self.assertRaises(ValidationError):
+ TestModel(x="\u0000")
+ with self.assertRaises(ValidationError):
+ TestModel(x="☃")
+
|