diff options
author | David Robertson <davidr@element.io> | 2022-05-22 18:00:10 +0100 |
---|---|---|
committer | David Robertson <davidr@element.io> | 2022-05-22 18:00:10 +0100 |
commit | 26cb343d40beb186eebbe7796596b079c0e24971 (patch) | |
tree | b3b0536820204b5d6737a4339bb0c77be27f0caa | |
parent | A few examples (diff) | |
download | synapse-26cb343d40beb186eebbe7796596b079c0e24971.tar.xz |
SSOAttributeRequirement
-rw-r--r-- | synapse/config/oidc2.py | 34 | ||||
-rw-r--r-- | tests/config/test_oidc2.py | 66 |
2 files changed, 90 insertions, 10 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py index eb00a622fc..ab1a79c272 100644 --- a/synapse/config/oidc2.py +++ b/synapse/config/oidc2.py @@ -44,6 +44,19 @@ class UserProfileMethod(str, Enum): userinfo_endpoint = "userinfo_endpoint" +class SSOAttributeRequirement(BaseModel): + class Config: + # Complain if someone provides a field that's not one of those listed here. + # Pydantic suggests making your own BaseModel subclass if you want to do this, + # see https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally + extra = "forbid" + + attribute: StrictStr + # Note: a comment in config/oidc.py suggests that `value` may be optional. But + # The JSON schema seems to forbid this. + value: StrictStr + + class OIDCProviderModel(BaseModel): """ Notes on Pydantic: @@ -166,14 +179,8 @@ class OIDCProviderModel(BaseModel): cls, jwks_uri: Optional[str], values: Mapping[str, object] ) -> Optional[str]: discovery_disabled = "discover" in values and not values["discover"] - openid_scope_requested = ( - "scopes" in values and "openid" in values["scopes"] - ) - if ( - discovery_disabled - and openid_scope_requested - and jwks_uri is None - ): + openid_scope_requested = "scopes" in values and "openid" in values["scopes"] + if discovery_disabled and openid_scope_requested and jwks_uri is None: raise ValueError( "jwks_uri is required if discovery is disabled and" "the 'openid' scope is not requested" @@ -192,7 +199,7 @@ class OIDCProviderModel(BaseModel): allow_existing_users: StrictBool = False # the class of the user mapping provider - # TODO + # TODO there was logic for this user_mapping_provider_class: Any # TODO: Type # the config of the user mapping provider @@ -200,9 +207,16 @@ class OIDCProviderModel(BaseModel): user_mapping_provider_config: Any # required attributes to require in userinfo to allow login/registration - attribute_requirements: Tuple[Any, ...] = () # TODO SsoAttributeRequirement] = () + # TODO: wouldn't this be better expressed as a Mapping[str, str]? + attribute_requirements: Tuple[SSOAttributeRequirement, ...] = () class LegacyOIDCProviderModel(OIDCProviderModel): + # These fields could be omitted in the old scheme. idp_id: IDP_ID_TYPE = "oidc" idp_name: StrictStr = "OIDC" + + +# TODO +# top-level config: check we don't have any duplicate idp_ids now +# compute callback url diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py index d5a6825503..c2d450e3c8 100644 --- a/tests/config/test_oidc2.py +++ b/tests/config/test_oidc2.py @@ -336,3 +336,69 @@ class PydanticOIDCTestCase(TestCase): del self.config["scopes"] self.config["userinfo_endpoint"] = None OIDCProviderModel.parse_obj(self.config) + + def test_attribute_requirements(self): + # Example of a field involving a nested model + model = OIDCProviderModel.parse_obj(self.config) + self.assertIsInstance(model.attribute_requirements, tuple) + self.assertEqual( + len(model.attribute_requirements), 1, model.attribute_requirements + ) + + # Bad tGypes should be rejected + for bad_value in 123, 456.0, False, None, {}, ["hello"]: + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = bad_value + OIDCProviderModel.parse_obj(self.config) + + # An empty list of requirements is okay, ... + self.config["attribute_requirements"] = [] + OIDCProviderModel.parse_obj(self.config) + + # ...as is an omitted list of requirements... + del self.config["attribute_requirements"] + OIDCProviderModel.parse_obj(self.config) + + # ...but not an explicit None. + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = None + OIDCProviderModel.parse_obj(self.config) + + # Multiple requirements are fine. + self.config["attribute_requirements"] = [{"attribute": "k", "value": "v"}] * 3 + model = OIDCProviderModel.parse_obj(self.config) + self.assertEqual( + len(model.attribute_requirements), 3, model.attribute_requirements + ) + + # The submodel's field types should be enforced too. + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{"attribute": "key", "value": 123}] + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{"attribute": 123, "value": "val"}] + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{"attribute": "a", "value": ["b"]}] + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{"attribute": "a", "value": None}] + OIDCProviderModel.parse_obj(self.config) + + # Missing fields in the submodel are an error. + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{"attribute": "a"}] + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{"value": "v"}] + OIDCProviderModel.parse_obj(self.config) + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [{}] + OIDCProviderModel.parse_obj(self.config) + + # Extra fields in the submodel are an error. + with self.assertRaises(ValidationError): + self.config["attribute_requirements"] = [ + {"attribute": "a", "value": "v", "answer": "forty-two"} + ] + OIDCProviderModel.parse_obj(self.config) |