diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
new file mode 100644
index 0000000000..64005deae9
--- /dev/null
+++ b/synapse/config/oidc2.py
@@ -0,0 +1,104 @@
+from typing import Optional, Tuple, Any
+
+from pydantic import BaseModel, StrictStr, validator, StrictBool
+from synapse.config.validators import string_length_between, string_contains_characters
+
+
+class OIDCProviderModel(BaseModel):
+ """
+ Notes on Pydantic:
+ - I've used StrictStr because a plain `str` accepts integers and calls str() on them
+ - I've factored out the validators here to demonstrate that we can avoid some duplication
+ if there are common patterns. Otherwise one could use @validator("field_name") and
+ define the validator function inline.
+ """
+
+ # a unique identifier for this identity provider. Used in the 'user_external_ids'
+ # table, as well as the query/path parameter used in the login protocol.
+ # TODO: this is optional in the old-style config, defaulting to "oidc".
+ idp_id: StrictStr
+ _idp_id_length = validator("idp_id")(string_length_between(1, 250))
+ _idp_id_characters = validator("idp_id")(
+ string_contains_characters("A-Za-z0-9._~-")
+ )
+
+ # user-facing name for this identity provider.
+ # TODO: this is optional in the old-style config, defaulting to "OIDC".
+ idp_name: StrictStr
+
+ # Optional MXC URI for icon for this IdP.
+ # TODO: validate that this is an MXC URI.
+ idp_icon: Optional[StrictStr]
+
+ # Optional brand identifier for this IdP.
+ idp_brand: Optional[StrictStr]
+
+ # whether the OIDC discovery mechanism is used to discover endpoints
+ discover: StrictBool = True
+
+ # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+ # discover the provider's endpoints.
+ issuer: StrictStr
+
+ # oauth2 client id to use
+ client_id: StrictStr
+
+ # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+ # a secret.
+ client_secret: Optional[StrictStr]
+
+ # key to use to construct a JWT to use as a client secret. May be `None` if
+ # `client_secret` is set.
+ # TODO
+ client_secret_jwt_key: Optional[Any] # OidcProviderClientSecretJwtKey]
+
+ # auth method to use when exchanging the token.
+ # Valid values are 'client_secret_basic', 'client_secret_post' and
+ # 'none'.
+ client_auth_method: StrictStr = "client_secret_basic"
+
+ # list of scopes to request
+ scopes: Tuple[StrictStr, ...] = ("openid",)
+
+ # the oauth2 authorization endpoint. Required if discovery is disabled.
+ # TODO: required if discovery is disabled
+ authorization_endpoint: Optional[StrictStr]
+
+ # the oauth2 token endpoint. Required if discovery is disabled.
+ # TODO: required if discovery is disabled
+ token_endpoint: Optional[StrictStr]
+
+ # the OIDC userinfo endpoint. Required if discovery is disabled and the
+ # "openid" scope is not requested.
+ # TODO: required if discovery is disabled and the openid scope isn't requested
+ userinfo_endpoint: Optional[StrictStr]
+
+ # URI where to fetch the JWKS. Required if discovery is disabled and the
+ # "openid" scope is used.
+ # TODO: required if discovery is disabled and the openid scope IS requested
+ jwks_uri: Optional[StrictStr]
+
+ # Whether to skip metadata verification
+ skip_verification: StrictBool = False
+
+ # Whether to fetch the user profile from the userinfo endpoint. Valid
+ # values are: "auto" or "userinfo_endpoint".
+ # TODO enum
+ user_profile_method: StrictStr = "auto"
+
+ # whether to allow a user logging in via OIDC to match a pre-existing account
+ # instead of failing
+ allow_existing_users: StrictBool = False
+
+ # the class of the user mapping provider
+ # TODO
+ user_mapping_provider_class: Any # TODO: Type
+
+ # the config of the user mapping provider
+ # TODO
+ user_mapping_provider_config: Any
+
+ # required attributes to require in userinfo to allow login/registration
+ attribute_requirements: Tuple[
+ Any, ...
+ ] = tuple() # TODO SsoAttributeRequirement] = tuple()
diff --git a/synapse/config/validators.py b/synapse/config/validators.py
new file mode 100644
index 0000000000..2faae0cf4a
--- /dev/null
+++ b/synapse/config/validators.py
@@ -0,0 +1,31 @@
+import re
+from typing import Type
+
+from pydantic import BaseModel
+from pydantic.fields import ModelField
+
+
+def string_length_between(lower: int, upper: int):
+ def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
+ print(f"validate {lower=} {upper=} {value=}")
+ if lower <= len(value) <= upper:
+ print("ok")
+ return value
+ print("bad")
+ raise ValueError(
+ f"{field.name} must be between {lower} and {upper} characters long"
+ )
+
+ return validator
+
+
+def string_contains_characters(charset: str):
+ def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
+ pattern = f"^[{charset}]*$"
+ if re.match(pattern, value):
+ return value
+ raise ValueError(
+ f"{field.name} must be only contain the characters {charset}"
+ )
+
+ return validator
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
new file mode 100644
index 0000000000..021eba9511
--- /dev/null
+++ b/tests/config/test_oidc2.py
@@ -0,0 +1,73 @@
+from copy import deepcopy
+from typing import Any, Mapping, Dict
+
+import yaml
+from pydantic import ValidationError
+
+from synapse.config.oidc2 import OIDCProviderModel
+from tests.unittest import TestCase
+
+
+SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load(
+ """
+idp_id: apple
+idp_name: Apple
+idp_icon: "mxc://matrix.org/blahblahblah"
+idp_brand: "apple"
+issuer: "https://appleid.apple.com"
+client_id: "org.matrix.synapse.sso.service"
+client_secret_jwt_key:
+ key: DUMMY_PRIVATE_KEY
+ jwt_header:
+ alg: ES256
+ kid: potato123
+ jwt_payload:
+ iss: issuer456
+client_auth_method: "client_secret_post"
+scopes: ["name", "email", "openid"]
+authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+user_mapping_provider:
+ config:
+ email_template: "{{ user.email }}"
+ localpart_template: "{{ user.email|localpart_from_email }}"
+ confirm_localpart: true
+"""
+)
+
+
+class PydanticOIDCTestCase(TestCase):
+ # Each test gets a dummy config it can change as it sees fit
+ config: Dict[str, Any]
+
+ def setUp(self) -> None:
+ self.config = deepcopy(SAMPLE_CONFIG)
+
+ def test_idp_id(self) -> None:
+ """Demonstrate that Pydantic validates idp_id correctly."""
+ # OIDCProviderModel.parse_obj(self.config)
+ #
+ # # Enforce that idp_id is required.
+ # with self.assertRaises(ValidationError):
+ # del self.config["idp_id"]
+ # OIDCProviderModel.parse_obj(self.config)
+ #
+ # # Enforce that idp_id is a string.
+ # with self.assertRaises(ValidationError):
+ # self.config["idp_id"] = 123
+ # OIDCProviderModel.parse_obj(self.config)
+ # with self.assertRaises(ValidationError):
+ # self.config["idp_id"] = None
+ # OIDCProviderModel.parse_obj(self.config)
+
+ # Enforce a length between 1 and 250.
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = ""
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = "a" * 251
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Enforce the character set
+ with self.assertRaises(ValidationError):
+ self.config["idp_id"] = "$"
+ OIDCProviderModel.parse_obj(self.config)
diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py
new file mode 100644
index 0000000000..7b57a13d5e
--- /dev/null
+++ b/tests/config/test_validators.py
@@ -0,0 +1,52 @@
+from unittest import TestCase
+
+from pydantic import BaseModel, ValidationError, validator, StrictStr
+
+from synapse.config.validators import string_length_between, string_contains_characters
+
+
+class TestValidators(TestCase):
+ def test_string_length_between(self) -> None:
+ class TestModel(BaseModel):
+ x: StrictStr
+ _x_length = validator("x")(string_length_between(5, 10))
+
+ with self.assertRaises(ValidationError):
+ TestModel(x="")
+ with self.assertRaises(ValidationError):
+ TestModel(x="a" * 4)
+
+ # Should not raise:
+ TestModel(x="a" * 5)
+ TestModel(x="a" * 10)
+
+ with self.assertRaises(ValidationError):
+ TestModel(x="a" * 11)
+ with self.assertRaises(ValidationError):
+ TestModel(x="a" * 1000)
+
+ def test_string_contains_characters(self) -> None:
+ class TestModel(BaseModel):
+ x: StrictStr
+ _x_characters = validator("x")(string_contains_characters("A-Z0-9"))
+
+ # Should not raise
+ TestModel(x="")
+ TestModel(x="A")
+ TestModel(x="B")
+ TestModel(x="Z")
+ TestModel(x="123456789")
+
+ with self.assertRaises(ValidationError):
+ TestModel(x="---")
+ with self.assertRaises(ValidationError):
+ TestModel(x="$")
+ with self.assertRaises(ValidationError):
+ TestModel(x="A$")
+ with self.assertRaises(ValidationError):
+ TestModel(x="a")
+ with self.assertRaises(ValidationError):
+ TestModel(x="\u0000")
+ with self.assertRaises(ValidationError):
+ TestModel(x="☃")
+
|