diff options
author | David Robertson <davidr@element.io> | 2022-05-18 23:12:29 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-05-19 10:29:27 +0100 |
commit | 348b53fe9c16068f86b0afe37bd31e959de81ae6 (patch) | |
tree | 8b5f4dd3e6ec96c010d24c70d04114f6be6b31a0 | |
parent | Require and lock `pydantic` (diff) | |
download | synapse-348b53fe9c16068f86b0afe37bd31e959de81ae6.tar.xz |
WIP trying out validators
-rw-r--r-- | synapse/config/oidc2.py | 104 | ||||
-rw-r--r-- | synapse/config/validators.py | 31 | ||||
-rw-r--r-- | tests/config/test_oidc2.py | 73 | ||||
-rw-r--r-- | tests/config/test_validators.py | 52 |
4 files changed, 260 insertions, 0 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py new file mode 100644 index 0000000000..64005deae9 --- /dev/null +++ b/synapse/config/oidc2.py @@ -0,0 +1,104 @@ +from typing import Optional, Tuple, Any + +from pydantic import BaseModel, StrictStr, validator, StrictBool +from synapse.config.validators import string_length_between, string_contains_characters + + +class OIDCProviderModel(BaseModel): + """ + Notes on Pydantic: + - I've used StrictStr because a plain `str` accepts integers and calls str() on them + - I've factored out the validators here to demonstrate that we can avoid some duplication + if there are common patterns. Otherwise one could use @validator("field_name") and + define the validator function inline. + """ + + # a unique identifier for this identity provider. Used in the 'user_external_ids' + # table, as well as the query/path parameter used in the login protocol. + # TODO: this is optional in the old-style config, defaulting to "oidc". + idp_id: StrictStr + _idp_id_length = validator("idp_id")(string_length_between(1, 250)) + _idp_id_characters = validator("idp_id")( + string_contains_characters("A-Za-z0-9._~-") + ) + + # user-facing name for this identity provider. + # TODO: this is optional in the old-style config, defaulting to "OIDC". + idp_name: StrictStr + + # Optional MXC URI for icon for this IdP. + # TODO: validate that this is an MXC URI. + idp_icon: Optional[StrictStr] + + # Optional brand identifier for this IdP. + idp_brand: Optional[StrictStr] + + # whether the OIDC discovery mechanism is used to discover endpoints + discover: StrictBool = True + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + issuer: StrictStr + + # oauth2 client id to use + client_id: StrictStr + + # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate + # a secret. + client_secret: Optional[StrictStr] + + # key to use to construct a JWT to use as a client secret. May be `None` if + # `client_secret` is set. + # TODO + client_secret_jwt_key: Optional[Any] # OidcProviderClientSecretJwtKey] + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic', 'client_secret_post' and + # 'none'. + client_auth_method: StrictStr = "client_secret_basic" + + # list of scopes to request + scopes: Tuple[StrictStr, ...] = ("openid",) + + # the oauth2 authorization endpoint. Required if discovery is disabled. + # TODO: required if discovery is disabled + authorization_endpoint: Optional[StrictStr] + + # the oauth2 token endpoint. Required if discovery is disabled. + # TODO: required if discovery is disabled + token_endpoint: Optional[StrictStr] + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + # TODO: required if discovery is disabled and the openid scope isn't requested + userinfo_endpoint: Optional[StrictStr] + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + # TODO: required if discovery is disabled and the openid scope IS requested + jwks_uri: Optional[StrictStr] + + # Whether to skip metadata verification + skip_verification: StrictBool = False + + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + # TODO enum + user_profile_method: StrictStr = "auto" + + # whether to allow a user logging in via OIDC to match a pre-existing account + # instead of failing + allow_existing_users: StrictBool = False + + # the class of the user mapping provider + # TODO + user_mapping_provider_class: Any # TODO: Type + + # the config of the user mapping provider + # TODO + user_mapping_provider_config: Any + + # required attributes to require in userinfo to allow login/registration + attribute_requirements: Tuple[ + Any, ... + ] = tuple() # TODO SsoAttributeRequirement] = tuple() diff --git a/synapse/config/validators.py b/synapse/config/validators.py new file mode 100644 index 0000000000..2faae0cf4a --- /dev/null +++ b/synapse/config/validators.py @@ -0,0 +1,31 @@ +import re +from typing import Type + +from pydantic import BaseModel +from pydantic.fields import ModelField + + +def string_length_between(lower: int, upper: int): + def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str: + print(f"validate {lower=} {upper=} {value=}") + if lower <= len(value) <= upper: + print("ok") + return value + print("bad") + raise ValueError( + f"{field.name} must be between {lower} and {upper} characters long" + ) + + return validator + + +def string_contains_characters(charset: str): + def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str: + pattern = f"^[{charset}]*$" + if re.match(pattern, value): + return value + raise ValueError( + f"{field.name} must be only contain the characters {charset}" + ) + + return validator diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py new file mode 100644 index 0000000000..021eba9511 --- /dev/null +++ b/tests/config/test_oidc2.py @@ -0,0 +1,73 @@ +from copy import deepcopy +from typing import Any, Mapping, Dict + +import yaml +from pydantic import ValidationError + +from synapse.config.oidc2 import OIDCProviderModel +from tests.unittest import TestCase + + +SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load( + """ +idp_id: apple +idp_name: Apple +idp_icon: "mxc://matrix.org/blahblahblah" +idp_brand: "apple" +issuer: "https://appleid.apple.com" +client_id: "org.matrix.synapse.sso.service" +client_secret_jwt_key: + key: DUMMY_PRIVATE_KEY + jwt_header: + alg: ES256 + kid: potato123 + jwt_payload: + iss: issuer456 +client_auth_method: "client_secret_post" +scopes: ["name", "email", "openid"] +authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post +user_mapping_provider: + config: + email_template: "{{ user.email }}" + localpart_template: "{{ user.email|localpart_from_email }}" + confirm_localpart: true +""" +) + + +class PydanticOIDCTestCase(TestCase): + # Each test gets a dummy config it can change as it sees fit + config: Dict[str, Any] + + def setUp(self) -> None: + self.config = deepcopy(SAMPLE_CONFIG) + + def test_idp_id(self) -> None: + """Demonstrate that Pydantic validates idp_id correctly.""" + # OIDCProviderModel.parse_obj(self.config) + # + # # Enforce that idp_id is required. + # with self.assertRaises(ValidationError): + # del self.config["idp_id"] + # OIDCProviderModel.parse_obj(self.config) + # + # # Enforce that idp_id is a string. + # with self.assertRaises(ValidationError): + # self.config["idp_id"] = 123 + # OIDCProviderModel.parse_obj(self.config) + # with self.assertRaises(ValidationError): + # self.config["idp_id"] = None + # OIDCProviderModel.parse_obj(self.config) + + # Enforce a length between 1 and 250. + with self.assertRaises(ValidationError): + self.config["idp_id"] = "" + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["idp_id"] = "a" * 251 + OIDCProviderModel.parse_obj(self.config) + + # Enforce the character set + with self.assertRaises(ValidationError): + self.config["idp_id"] = "$" + OIDCProviderModel.parse_obj(self.config) diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py new file mode 100644 index 0000000000..7b57a13d5e --- /dev/null +++ b/tests/config/test_validators.py @@ -0,0 +1,52 @@ +from unittest import TestCase + +from pydantic import BaseModel, ValidationError, validator, StrictStr + +from synapse.config.validators import string_length_between, string_contains_characters + + +class TestValidators(TestCase): + def test_string_length_between(self) -> None: + class TestModel(BaseModel): + x: StrictStr + _x_length = validator("x")(string_length_between(5, 10)) + + with self.assertRaises(ValidationError): + TestModel(x="") + with self.assertRaises(ValidationError): + TestModel(x="a" * 4) + + # Should not raise: + TestModel(x="a" * 5) + TestModel(x="a" * 10) + + with self.assertRaises(ValidationError): + TestModel(x="a" * 11) + with self.assertRaises(ValidationError): + TestModel(x="a" * 1000) + + def test_string_contains_characters(self) -> None: + class TestModel(BaseModel): + x: StrictStr + _x_characters = validator("x")(string_contains_characters("A-Z0-9")) + + # Should not raise + TestModel(x="") + TestModel(x="A") + TestModel(x="B") + TestModel(x="Z") + TestModel(x="123456789") + + with self.assertRaises(ValidationError): + TestModel(x="---") + with self.assertRaises(ValidationError): + TestModel(x="$") + with self.assertRaises(ValidationError): + TestModel(x="A$") + with self.assertRaises(ValidationError): + TestModel(x="a") + with self.assertRaises(ValidationError): + TestModel(x="\u0000") + with self.assertRaises(ValidationError): + TestModel(x="☃") + |