diff options
author | David Robertson <davidr@element.io> | 2022-05-20 00:56:02 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-05-20 00:56:02 +0100 |
commit | 2179c6376aa9fbdf650518db35de17e1fe8b6377 (patch) | |
tree | ed0b05152b111f01ede04d9d7d324296bd816cd3 | |
parent | move TYPE_CHECKING workaround outside (diff) | |
download | synapse-2179c6376aa9fbdf650518db35de17e1fe8b6377.tar.xz |
Extra fields and tests
Pleasantly: no pain here
-rw-r--r-- | synapse/config/oidc2.py | 40 | ||||
-rw-r--r-- | tests/config/test_oidc2.py | 159 |
2 files changed, 180 insertions, 19 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py index 12059f5bef..5224a255bd 100644 --- a/synapse/config/oidc2.py +++ b/synapse/config/oidc2.py @@ -1,3 +1,5 @@ +from enum import Enum + from typing import TYPE_CHECKING, Any, Optional, Tuple from pydantic import BaseModel, StrictBool, StrictStr, constr @@ -7,6 +9,7 @@ from pydantic import BaseModel, StrictBool, StrictStr, constr if TYPE_CHECKING: IDP_ID_TYPE = str + IDP_BRAND_TYPE = str else: IDP_ID_TYPE = constr( strict=True, @@ -14,15 +17,39 @@ else: max_length=250, regex="^[A-Za-z0-9._~-]+$", # noqa: F722 ) + IDP_BRAND_TYPE = constr( + strict=True, + min_length=1, + max_length=255, + regex="^[a-z][a-z0-9_.-]*$", # noqa: F722 + ) + +# the following list of enum members is the same as the keys of +# authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it +# to avoid importing authlib here. +class ClientAuthMethods(str, Enum): + # The duplication is unfortunate. 3.11 should have StrEnum though, + # and there is a backport available for 3.8.6. + client_secret_basic = "client_secret_basic" + client_secret_post = "client_secret_post" + none = "none" + + +class UserProfileMethod(str, Enum): + # The duplication is unfortunate. 3.11 should have StrEnum though, + # and there is a backport available for 3.8.6. + auto = "auto" + userinfo_endpoint = "userinfo_endpoint" class OIDCProviderModel(BaseModel): """ Notes on Pydantic: - - I've used StrictStr because a plain `str` accepts integers and calls str() on them - - I've factored out the validators here to demonstrate that we can avoid some duplication - if there are common patterns. Otherwise one could use @validator("field_name") and - define the validator function inline. + - I've used StrictStr because a plain `str` e.g. accepts integers and calls str() + on them + - pulling out constr() into IDP_ID_TYPE is a little awkward, but necessary to keep + mypy happy + - """ # a unique identifier for this identity provider. Used in the 'user_external_ids' @@ -63,7 +90,7 @@ class OIDCProviderModel(BaseModel): # auth method to use when exchanging the token. # Valid values are 'client_secret_basic', 'client_secret_post' and # 'none'. - client_auth_method: StrictStr = "client_secret_basic" + client_auth_method: ClientAuthMethods = ClientAuthMethods.client_secret_basic # list of scopes to request scopes: Tuple[StrictStr, ...] = ("openid",) @@ -91,8 +118,7 @@ class OIDCProviderModel(BaseModel): # Whether to fetch the user profile from the userinfo endpoint. Valid # values are: "auto" or "userinfo_endpoint". - # TODO enum - user_profile_method: StrictStr = "auto" + user_profile_method: UserProfileMethod = UserProfileMethod.auto # whether to allow a user logging in via OIDC to match a pre-existing account # instead of failing diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py index 102b35f9cc..e340a7d43b 100644 --- a/tests/config/test_oidc2.py +++ b/tests/config/test_oidc2.py @@ -4,7 +4,7 @@ from typing import Any, Dict import yaml from pydantic import ValidationError -from synapse.config.oidc2 import OIDCProviderModel +from synapse.config.oidc2 import OIDCProviderModel, ClientAuthMethods from tests.unittest import TestCase @@ -36,29 +36,31 @@ user_mapping_provider: class PydanticOIDCTestCase(TestCase): + """Examples to build confidence that pydantic is doing the validation we think + it's doing""" + # Each test gets a dummy config it can change as it sees fit config: Dict[str, Any] def setUp(self) -> None: self.config = deepcopy(SAMPLE_CONFIG) - def test_idp_id(self) -> None: - """Demonstrate that Pydantic validates idp_id correctly.""" + def test_example_config(self): + # Check that parsing the sample config doesn't raise an error. OIDCProviderModel.parse_obj(self.config) + def test_idp_id(self) -> None: + """Example of using a Pydantic constr() field without a default.""" # Enforce that idp_id is required. with self.assertRaises(ValidationError): del self.config["idp_id"] OIDCProviderModel.parse_obj(self.config) # Enforce that idp_id is a string. - with self.assertRaises(ValidationError) as e: - self.config["idp_id"] = 123 - OIDCProviderModel.parse_obj(self.config) - print(e.exception) - with self.assertRaises(ValidationError): - self.config["idp_id"] = None - OIDCProviderModel.parse_obj(self.config) + for bad_vlaue in 123, None, ["a"], {"a": "b"}: + with self.assertRaises(ValidationError) as e: + self.config["idp_id"] = bad_vlaue + OIDCProviderModel.parse_obj(self.config) # Enforce a length between 1 and 250. with self.assertRaises(ValidationError): @@ -68,7 +70,7 @@ class PydanticOIDCTestCase(TestCase): self.config["idp_id"] = "a" * 251 OIDCProviderModel.parse_obj(self.config) - # Enforce the character set + # Enforce the regex with self.assertRaises(ValidationError): self.config["idp_id"] = "$" OIDCProviderModel.parse_obj(self.config) @@ -77,4 +79,137 @@ class PydanticOIDCTestCase(TestCase): with self.assertRaises(ValidationError) as e: self.config["idp_id"] = "$" * 500 OIDCProviderModel.parse_obj(self.config) - print(e.exception) + + def test_issuer(self) -> None: + """Example of a StrictStr field without a default.""" + + # Empty and nonempty strings should be accepted. + for good_value in "", "hello", "hello" * 1000, "☃": + self.config["issuer"] = good_value + OIDCProviderModel.parse_obj(self.config) + + # Invalid types should be rejected. + for bad_value in 123, None, ["h", "e", "l", "l", "o"], {"hello": "there"}: + with self.assertRaises(ValidationError): + self.config["issuer"] = bad_value + OIDCProviderModel.parse_obj(self.config) + + # A missing issuer should be rejected. + with self.assertRaises(ValidationError): + del self.config["issuer"] + OIDCProviderModel.parse_obj(self.config) + + def test_idp_brand(self) -> None: + """Example of an Optional[StrictStr] field.""" + # Empty and nonempty strings should be accepted. + for good_value in "", "hello", "hello" * 1000, "☃": + self.config["idp_brand"] = good_value + OIDCProviderModel.parse_obj(self.config) + + # Invalid types should be rejected. + for bad_value in 123, ["h", "e", "l", "l", "o"], {"hello": "there"}: + with self.assertRaises(ValidationError): + self.config["idp_brand"] = bad_value + OIDCProviderModel.parse_obj(self.config) + + # A lack of an idp_brand is fine... + del self.config["idp_brand"] + model = OIDCProviderModel.parse_obj(self.config) + self.assertIsNone(model.idp_brand) + + # ... and interpreted the same as an explicit `None`. + self.config["idp_brand"] = None + model = OIDCProviderModel.parse_obj(self.config) + self.assertIsNone(model.idp_brand) + + def test_discover(self) -> None: + """Example of a StrictBool field with a default.""" + # Booleans are permitted. + for value in True, False: + self.config["discover"] = value + model = OIDCProviderModel.parse_obj(self.config) + self.assertEqual(model.discover, value) + + # Invalid types should be rejected. + for bad_value in ( + -1.0, + 0, + 1, + float("nan"), + "yes", + "NO", + "True", + "true", + None, + "None", + "null", + ["a"], + {"a": "b"}, + ): + self.config["discover"] = bad_value + with self.assertRaises(ValidationError): + OIDCProviderModel.parse_obj(self.config) + + # A missing value is okay, because this field has a default. + del self.config["discover"] + model = OIDCProviderModel.parse_obj(self.config) + self.assertIs(model.discover, True) + + def test_client_auth_method(self) -> None: + """This is an example of using a Pydantic string enum field.""" + # check the allowed values are permitted and deserialise to an enum member + for method in "client_secret_basic", "client_secret_post", "none": + self.config["client_auth_method"] = method + model = OIDCProviderModel.parse_obj(self.config) + self.assertIs(model.client_auth_method, ClientAuthMethods[method]) + + # check the default applies if no auth method is provided. + del self.config["client_auth_method"] + model = OIDCProviderModel.parse_obj(self.config) + self.assertIs(model.client_auth_method, ClientAuthMethods.client_secret_basic) + + # Check invalid types are rejected + for bad_value in 123, ["client_secret_basic"], {"a": 1}, None: + with self.assertRaises(ValidationError): + self.config["client_auth_method"] = bad_value + OIDCProviderModel.parse_obj(self.config) + + # Check that disallowed strings are rejected + with self.assertRaises(ValidationError): + self.config["client_auth_method"] = "No, Luke, _I_ am your father!" + OIDCProviderModel.parse_obj(self.config) + + def test_scopes(self) -> None: + """Example of a Tuple[StrictStr] with a default.""" + # Check that the parsed object holds a tuple + self.config["scopes"] = [] + model = OIDCProviderModel.parse_obj(self.config) + self.assertEqual(model.scopes, ()) + + # Check a variety of list lengths are accepted. + for good_value in ["aa"], ["hello", "world"], ["a"] * 4, [""] * 20: + self.config["scopes"] = good_value + model = OIDCProviderModel.parse_obj(self.config) + self.assertEqual(model.scopes, tuple(good_value)) + + # Check invalid types are rejected. + for bad_value in ( + "", + "abc", + 123, + {}, + {"a": 1}, + None, + [None], + [["a"]], + [{}], + [456], + ): + with self.assertRaises(ValidationError): + self.config["scopes"] = bad_value + OIDCProviderModel.parse_obj(self.config) + + # Check that "scopes" may be omitted. + del self.config["scopes"] + model = OIDCProviderModel.parse_obj(self.config) + self.assertEqual(model.scopes, ("openid",)) |