diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
index eb00a622fc..ab1a79c272 100644
--- a/synapse/config/oidc2.py
+++ b/synapse/config/oidc2.py
@@ -44,6 +44,19 @@ class UserProfileMethod(str, Enum):
userinfo_endpoint = "userinfo_endpoint"
+class SSOAttributeRequirement(BaseModel):
+ class Config:
+ # Complain if someone provides a field that's not one of those listed here.
+ # Pydantic suggests making your own BaseModel subclass if you want to do this,
+ # see https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+ extra = "forbid"
+
+ attribute: StrictStr
+ # Note: a comment in config/oidc.py suggests that `value` may be optional. But
+ # The JSON schema seems to forbid this.
+ value: StrictStr
+
+
class OIDCProviderModel(BaseModel):
"""
Notes on Pydantic:
@@ -166,14 +179,8 @@ class OIDCProviderModel(BaseModel):
cls, jwks_uri: Optional[str], values: Mapping[str, object]
) -> Optional[str]:
discovery_disabled = "discover" in values and not values["discover"]
- openid_scope_requested = (
- "scopes" in values and "openid" in values["scopes"]
- )
- if (
- discovery_disabled
- and openid_scope_requested
- and jwks_uri is None
- ):
+ openid_scope_requested = "scopes" in values and "openid" in values["scopes"]
+ if discovery_disabled and openid_scope_requested and jwks_uri is None:
raise ValueError(
"jwks_uri is required if discovery is disabled and"
"the 'openid' scope is not requested"
@@ -192,7 +199,7 @@ class OIDCProviderModel(BaseModel):
allow_existing_users: StrictBool = False
# the class of the user mapping provider
- # TODO
+ # TODO there was logic for this
user_mapping_provider_class: Any # TODO: Type
# the config of the user mapping provider
@@ -200,9 +207,16 @@ class OIDCProviderModel(BaseModel):
user_mapping_provider_config: Any
# required attributes to require in userinfo to allow login/registration
- attribute_requirements: Tuple[Any, ...] = () # TODO SsoAttributeRequirement] = ()
+ # TODO: wouldn't this be better expressed as a Mapping[str, str]?
+ attribute_requirements: Tuple[SSOAttributeRequirement, ...] = ()
class LegacyOIDCProviderModel(OIDCProviderModel):
+ # These fields could be omitted in the old scheme.
idp_id: IDP_ID_TYPE = "oidc"
idp_name: StrictStr = "OIDC"
+
+
+# TODO
+# top-level config: check we don't have any duplicate idp_ids now
+# compute callback url
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
index d5a6825503..c2d450e3c8 100644
--- a/tests/config/test_oidc2.py
+++ b/tests/config/test_oidc2.py
@@ -336,3 +336,69 @@ class PydanticOIDCTestCase(TestCase):
del self.config["scopes"]
self.config["userinfo_endpoint"] = None
OIDCProviderModel.parse_obj(self.config)
+
+ def test_attribute_requirements(self):
+ # Example of a field involving a nested model
+ model = OIDCProviderModel.parse_obj(self.config)
+ self.assertIsInstance(model.attribute_requirements, tuple)
+ self.assertEqual(
+ len(model.attribute_requirements), 1, model.attribute_requirements
+ )
+
+ # Bad tGypes should be rejected
+ for bad_value in 123, 456.0, False, None, {}, ["hello"]:
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = bad_value
+ OIDCProviderModel.parse_obj(self.config)
+
+ # An empty list of requirements is okay, ...
+ self.config["attribute_requirements"] = []
+ OIDCProviderModel.parse_obj(self.config)
+
+ # ...as is an omitted list of requirements...
+ del self.config["attribute_requirements"]
+ OIDCProviderModel.parse_obj(self.config)
+
+ # ...but not an explicit None.
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = None
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Multiple requirements are fine.
+ self.config["attribute_requirements"] = [{"attribute": "k", "value": "v"}] * 3
+ model = OIDCProviderModel.parse_obj(self.config)
+ self.assertEqual(
+ len(model.attribute_requirements), 3, model.attribute_requirements
+ )
+
+ # The submodel's field types should be enforced too.
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{"attribute": "key", "value": 123}]
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{"attribute": 123, "value": "val"}]
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{"attribute": "a", "value": ["b"]}]
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{"attribute": "a", "value": None}]
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Missing fields in the submodel are an error.
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{"attribute": "a"}]
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{"value": "v"}]
+ OIDCProviderModel.parse_obj(self.config)
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [{}]
+ OIDCProviderModel.parse_obj(self.config)
+
+ # Extra fields in the submodel are an error.
+ with self.assertRaises(ValidationError):
+ self.config["attribute_requirements"] = [
+ {"attribute": "a", "value": "v", "answer": "forty-two"}
+ ]
+ OIDCProviderModel.parse_obj(self.config)
|