summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-18 23:12:29 +0100
committerDavid Robertson <davidr@element.io>2022-05-19 10:29:27 +0100
commit348b53fe9c16068f86b0afe37bd31e959de81ae6 (patch)
tree8b5f4dd3e6ec96c010d24c70d04114f6be6b31a0
parentRequire and lock `pydantic` (diff)
downloadsynapse-348b53fe9c16068f86b0afe37bd31e959de81ae6.tar.xz
WIP trying out validators
-rw-r--r--synapse/config/oidc2.py104
-rw-r--r--synapse/config/validators.py31
-rw-r--r--tests/config/test_oidc2.py73
-rw-r--r--tests/config/test_validators.py52
4 files changed, 260 insertions, 0 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
new file mode 100644
index 0000000000..64005deae9
--- /dev/null
+++ b/synapse/config/oidc2.py
@@ -0,0 +1,104 @@
+from typing import Optional, Tuple, Any
+
+from pydantic import BaseModel, StrictStr, validator, StrictBool
+from synapse.config.validators import string_length_between, string_contains_characters
+
+
+class OIDCProviderModel(BaseModel):
+    """
+    Notes on Pydantic:
+    - I've used StrictStr because a plain `str` accepts integers and calls str() on them
+    - I've factored out the validators here to demonstrate that we can avoid some duplication
+      if there are common patterns. Otherwise one could use @validator("field_name") and
+      define the validator function inline.
+    """
+
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    # TODO: this is optional in the old-style config, defaulting to "oidc".
+    idp_id: StrictStr
+    _idp_id_length = validator("idp_id")(string_length_between(1, 250))
+    _idp_id_characters = validator("idp_id")(
+        string_contains_characters("A-Za-z0-9._~-")
+    )
+
+    # user-facing name for this identity provider.
+    # TODO: this is optional in the old-style config, defaulting to "OIDC".
+    idp_name: StrictStr
+
+    # Optional MXC URI for icon for this IdP.
+    # TODO: validate that this is an MXC URI.
+    idp_icon: Optional[StrictStr]
+
+    # Optional brand identifier for this IdP.
+    idp_brand: Optional[StrictStr]
+
+    # whether the OIDC discovery mechanism is used to discover endpoints
+    discover: StrictBool = True
+
+    # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
+    # discover the provider's endpoints.
+    issuer: StrictStr
+
+    # oauth2 client id to use
+    client_id: StrictStr
+
+    # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+    # a secret.
+    client_secret: Optional[StrictStr]
+
+    # key to use to construct a JWT to use as a client secret. May be `None` if
+    # `client_secret` is set.
+    # TODO
+    client_secret_jwt_key: Optional[Any]  # OidcProviderClientSecretJwtKey]
+
+    # auth method to use when exchanging the token.
+    # Valid values are 'client_secret_basic', 'client_secret_post' and
+    # 'none'.
+    client_auth_method: StrictStr = "client_secret_basic"
+
+    # list of scopes to request
+    scopes: Tuple[StrictStr, ...] = ("openid",)
+
+    # the oauth2 authorization endpoint. Required if discovery is disabled.
+    # TODO: required if discovery is disabled
+    authorization_endpoint: Optional[StrictStr]
+
+    # the oauth2 token endpoint. Required if discovery is disabled.
+    # TODO: required if discovery is disabled
+    token_endpoint: Optional[StrictStr]
+
+    # the OIDC userinfo endpoint. Required if discovery is disabled and the
+    # "openid" scope is not requested.
+    # TODO: required if discovery is disabled and the openid scope isn't requested
+    userinfo_endpoint: Optional[StrictStr]
+
+    # URI where to fetch the JWKS. Required if discovery is disabled and the
+    # "openid" scope is used.
+    # TODO: required if discovery is disabled and the openid scope IS requested
+    jwks_uri: Optional[StrictStr]
+
+    # Whether to skip metadata verification
+    skip_verification: StrictBool = False
+
+    # Whether to fetch the user profile from the userinfo endpoint. Valid
+    # values are: "auto" or "userinfo_endpoint".
+    # TODO enum
+    user_profile_method: StrictStr = "auto"
+
+    # whether to allow a user logging in via OIDC to match a pre-existing account
+    # instead of failing
+    allow_existing_users: StrictBool = False
+
+    # the class of the user mapping provider
+    # TODO
+    user_mapping_provider_class: Any  # TODO: Type
+
+    # the config of the user mapping provider
+    # TODO
+    user_mapping_provider_config: Any
+
+    # required attributes to require in userinfo to allow login/registration
+    attribute_requirements: Tuple[
+        Any, ...
+    ] = tuple()  # TODO SsoAttributeRequirement] = tuple()
diff --git a/synapse/config/validators.py b/synapse/config/validators.py
new file mode 100644
index 0000000000..2faae0cf4a
--- /dev/null
+++ b/synapse/config/validators.py
@@ -0,0 +1,31 @@
+import re
+from typing import Type
+
+from pydantic import BaseModel
+from pydantic.fields import ModelField
+
+
+def string_length_between(lower: int, upper: int):
+    def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
+        print(f"validate {lower=} {upper=} {value=}")
+        if lower <= len(value) <= upper:
+            print("ok")
+            return value
+        print("bad")
+        raise ValueError(
+            f"{field.name} must be between {lower} and {upper} characters long"
+        )
+
+    return validator
+
+
+def string_contains_characters(charset: str):
+    def validator(cls: Type[BaseModel], value: str, field: ModelField) -> str:
+        pattern = f"^[{charset}]*$"
+        if re.match(pattern, value):
+            return value
+        raise ValueError(
+            f"{field.name} must be only contain the characters {charset}"
+        )
+
+    return validator
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
new file mode 100644
index 0000000000..021eba9511
--- /dev/null
+++ b/tests/config/test_oidc2.py
@@ -0,0 +1,73 @@
+from copy import deepcopy
+from typing import Any, Mapping, Dict
+
+import yaml
+from pydantic import ValidationError
+
+from synapse.config.oidc2 import OIDCProviderModel
+from tests.unittest import TestCase
+
+
+SAMPLE_CONFIG: Mapping[str, Any] = yaml.safe_load(
+    """
+idp_id: apple
+idp_name: Apple
+idp_icon: "mxc://matrix.org/blahblahblah"
+idp_brand: "apple"
+issuer: "https://appleid.apple.com"
+client_id: "org.matrix.synapse.sso.service"
+client_secret_jwt_key:
+  key: DUMMY_PRIVATE_KEY
+  jwt_header:
+    alg: ES256
+    kid: potato123
+  jwt_payload:
+    iss: issuer456
+client_auth_method: "client_secret_post"
+scopes: ["name", "email", "openid"]
+authorization_endpoint: https://appleid.apple.com/auth/authorize?response_mode=form_post
+user_mapping_provider:
+  config:
+    email_template: "{{ user.email }}"
+    localpart_template: "{{ user.email|localpart_from_email }}"
+    confirm_localpart: true
+"""
+)
+
+
+class PydanticOIDCTestCase(TestCase):
+    # Each test gets a dummy config it can change as it sees fit
+    config: Dict[str, Any]
+
+    def setUp(self) -> None:
+        self.config = deepcopy(SAMPLE_CONFIG)
+
+    def test_idp_id(self) -> None:
+        """Demonstrate that Pydantic validates idp_id correctly."""
+        # OIDCProviderModel.parse_obj(self.config)
+        #
+        # # Enforce that idp_id is required.
+        # with self.assertRaises(ValidationError):
+        #     del self.config["idp_id"]
+        #     OIDCProviderModel.parse_obj(self.config)
+        #
+        # # Enforce that idp_id is a string.
+        # with self.assertRaises(ValidationError):
+        #     self.config["idp_id"] = 123
+        #     OIDCProviderModel.parse_obj(self.config)
+        # with self.assertRaises(ValidationError):
+        #     self.config["idp_id"] = None
+        #     OIDCProviderModel.parse_obj(self.config)
+
+        # Enforce a length between 1 and 250.
+        with self.assertRaises(ValidationError):
+            self.config["idp_id"] = ""
+            OIDCProviderModel.parse_obj(self.config)
+        with self.assertRaises(ValidationError):
+            self.config["idp_id"] = "a" * 251
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Enforce the character set
+        with self.assertRaises(ValidationError):
+            self.config["idp_id"] = "$"
+            OIDCProviderModel.parse_obj(self.config)
diff --git a/tests/config/test_validators.py b/tests/config/test_validators.py
new file mode 100644
index 0000000000..7b57a13d5e
--- /dev/null
+++ b/tests/config/test_validators.py
@@ -0,0 +1,52 @@
+from unittest import TestCase
+
+from pydantic import BaseModel, ValidationError, validator, StrictStr
+
+from synapse.config.validators import string_length_between, string_contains_characters
+
+
+class TestValidators(TestCase):
+    def test_string_length_between(self) -> None:
+        class TestModel(BaseModel):
+            x: StrictStr
+            _x_length = validator("x")(string_length_between(5, 10))
+
+        with self.assertRaises(ValidationError):
+            TestModel(x="")
+        with self.assertRaises(ValidationError):
+            TestModel(x="a" * 4)
+
+        # Should not raise:
+        TestModel(x="a" * 5)
+        TestModel(x="a" * 10)
+
+        with self.assertRaises(ValidationError):
+            TestModel(x="a" * 11)
+        with self.assertRaises(ValidationError):
+            TestModel(x="a" * 1000)
+
+    def test_string_contains_characters(self) -> None:
+        class TestModel(BaseModel):
+            x: StrictStr
+            _x_characters = validator("x")(string_contains_characters("A-Z0-9"))
+
+        # Should not raise
+        TestModel(x="")
+        TestModel(x="A")
+        TestModel(x="B")
+        TestModel(x="Z")
+        TestModel(x="123456789")
+
+        with self.assertRaises(ValidationError):
+            TestModel(x="---")
+        with self.assertRaises(ValidationError):
+            TestModel(x="$")
+        with self.assertRaises(ValidationError):
+            TestModel(x="A$")
+        with self.assertRaises(ValidationError):
+            TestModel(x="a")
+        with self.assertRaises(ValidationError):
+            TestModel(x="\u0000")
+        with self.assertRaises(ValidationError):
+            TestModel(x="☃")
+