summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/config/test_oidc2.py73
-rw-r--r--tests/config/test_validators.py52
2 files changed, 125 insertions, 0 deletions
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
new file mode 100644

index 0000000000..021eba9511 --- /dev/null +++ b/tests/config/test_oidc2.py
@@ -0,0 +1,73 @@ +from copy import deepcopy +from typing import Any, Mapping, Dict + +import yaml +from pydantic import ValidationError + +from synapse.config.oidc2 import OIDCProviderModel +from tests.unittest import TestCase + + +SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load( + """ +idp_id: apple +idp_name: Apple +idp_icon: "mxc://matrix.org/blahblahblah" +idp_brand: "apple" +issuer: "https://appleid.apple.com" +client_id: "org.matrix.synapse.sso.service" +client_secret_jwt_key: + key: DUMMY_PRIVATE_KEY + jwt_header: + alg: ES256 + kid: potato123 + jwt_payload: + iss: issuer456 +client_auth_method: "client_secret_post" +scopes: ["name", "email", "openid"] +authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post +user_mapping_provider: + config: + email_template: "{{ user.email }}" + localpart_template: "{{ user.email|localpart_from_email }}" + confirm_localpart: true +""" +) + + +class PydanticOIDCTestCase(TestCase): + # Each test gets a dummy config it can change as it sees fit + config: Dict[str, Any] + + def setUp(self) -> None: + self.config = deepcopy(SAMPLE_CONFIG) + + def test_idp_id(self) -> None: + """Demonstrate that Pydantic validates idp_id correctly.""" + # OIDCProviderModel.parse_obj(self.config) + # + # # Enforce that idp_id is required. + # with self.assertRaises(ValidationError): + # del self.config["idp_id"] + # OIDCProviderModel.parse_obj(self.config) + # + # # Enforce that idp_id is a string. + # with self.assertRaises(ValidationError): + # self.config["idp_id"] = 123 + # OIDCProviderModel.parse_obj(self.config) + # with self.assertRaises(ValidationError): + # self.config["idp_id"] = None + # OIDCProviderModel.parse_obj(self.config) + + # Enforce a length between 1 and 250. + with self.assertRaises(ValidationError): + self.config["idp_id"] = "" + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["idp_id"] = "a" * 251 + OIDCProviderModel.parse_obj(self.config) + + # Enforce the character set + with self.assertRaises(ValidationError): + self.config["idp_id"] = "$" + OIDCProviderModel.parse_obj(self.config) diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py new file mode 100644
index 0000000000..7b57a13d5e --- /dev/null +++ b/tests/config/test_validators.py
@@ -0,0 +1,52 @@ +from unittest import TestCase + +from pydantic import BaseModel, ValidationError, validator, StrictStr + +from synapse.config.validators import string_length_between, string_contains_characters + + +class TestValidators(TestCase): + def test_string_length_between(self) -> None: + class TestModel(BaseModel): + x: StrictStr + _x_length = validator("x")(string_length_between(5, 10)) + + with self.assertRaises(ValidationError): + TestModel(x="") + with self.assertRaises(ValidationError): + TestModel(x="a" * 4) + + # Should not raise: + TestModel(x="a" * 5) + TestModel(x="a" * 10) + + with self.assertRaises(ValidationError): + TestModel(x="a" * 11) + with self.assertRaises(ValidationError): + TestModel(x="a" * 1000) + + def test_string_contains_characters(self) -> None: + class TestModel(BaseModel): + x: StrictStr + _x_characters = validator("x")(string_contains_characters("A-Z0-9")) + + # Should not raise + TestModel(x="") + TestModel(x="A") + TestModel(x="B") + TestModel(x="Z") + TestModel(x="123456789") + + with self.assertRaises(ValidationError): + TestModel(x="---") + with self.assertRaises(ValidationError): + TestModel(x="$") + with self.assertRaises(ValidationError): + TestModel(x="A$") + with self.assertRaises(ValidationError): + TestModel(x="a") + with self.assertRaises(ValidationError): + TestModel(x="\u0000") + with self.assertRaises(ValidationError): + TestModel(x="☃") +