summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-22 18:00:10 +0100
committerDavid Robertson <davidr@element.io>2022-05-22 18:00:10 +0100
commit26cb343d40beb186eebbe7796596b079c0e24971 (patch)
treeb3b0536820204b5d6737a4339bb0c77be27f0caa
parentA few examples (diff)
downloadsynapse-26cb343d40beb186eebbe7796596b079c0e24971.tar.xz
SSOAttributeRequirement
-rw-r--r--synapse/config/oidc2.py34
-rw-r--r--tests/config/test_oidc2.py66
2 files changed, 90 insertions, 10 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
index eb00a622fc..ab1a79c272 100644
--- a/synapse/config/oidc2.py
+++ b/synapse/config/oidc2.py
@@ -44,6 +44,19 @@ class UserProfileMethod(str, Enum):
     userinfo_endpoint = "userinfo_endpoint"
 
 
+class SSOAttributeRequirement(BaseModel):
+    class Config:
+        # Complain if someone provides a field that's not one of those listed here.
+        # Pydantic suggests making your own BaseModel subclass if you want to do this,
+        # see https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+        extra = "forbid"
+
+    attribute: StrictStr
+    # Note: a comment in config/oidc.py suggests that `value` may be optional. But
+    # The JSON schema seems to forbid this.
+    value: StrictStr
+
+
 class OIDCProviderModel(BaseModel):
     """
     Notes on Pydantic:
@@ -166,14 +179,8 @@ class OIDCProviderModel(BaseModel):
         cls, jwks_uri: Optional[str], values: Mapping[str, object]
     ) -> Optional[str]:
         discovery_disabled = "discover" in values and not values["discover"]
-        openid_scope_requested = (
-            "scopes" in values and "openid" in values["scopes"]
-        )
-        if (
-            discovery_disabled
-            and openid_scope_requested
-            and jwks_uri is None
-        ):
+        openid_scope_requested = "scopes" in values and "openid" in values["scopes"]
+        if discovery_disabled and openid_scope_requested and jwks_uri is None:
             raise ValueError(
                 "jwks_uri is required if discovery is disabled and"
                 "the 'openid' scope is not requested"
@@ -192,7 +199,7 @@ class OIDCProviderModel(BaseModel):
     allow_existing_users: StrictBool = False
 
     # the class of the user mapping provider
-    # TODO
+    # TODO there was logic for this
     user_mapping_provider_class: Any  # TODO: Type
 
     # the config of the user mapping provider
@@ -200,9 +207,16 @@ class OIDCProviderModel(BaseModel):
     user_mapping_provider_config: Any
 
     # required attributes to require in userinfo to allow login/registration
-    attribute_requirements: Tuple[Any, ...] = ()  # TODO SsoAttributeRequirement] = ()
+    # TODO: wouldn't this be better expressed as a Mapping[str, str]?
+    attribute_requirements: Tuple[SSOAttributeRequirement, ...] = ()
 
 
 class LegacyOIDCProviderModel(OIDCProviderModel):
+    # These fields could be omitted in the old scheme.
     idp_id: IDP_ID_TYPE = "oidc"
     idp_name: StrictStr = "OIDC"
+
+
+# TODO
+# top-level config: check we don't have any duplicate idp_ids now
+# compute callback url
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
index d5a6825503..c2d450e3c8 100644
--- a/tests/config/test_oidc2.py
+++ b/tests/config/test_oidc2.py
@@ -336,3 +336,69 @@ class PydanticOIDCTestCase(TestCase):
         del self.config["scopes"]
         self.config["userinfo_endpoint"] = None
         OIDCProviderModel.parse_obj(self.config)
+
+    def test_attribute_requirements(self):
+        # Example of a field involving a nested model
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertIsInstance(model.attribute_requirements, tuple)
+        self.assertEqual(
+            len(model.attribute_requirements), 1, model.attribute_requirements
+        )
+
+        # Bad tGypes should be rejected
+        for bad_value in 123, 456.0, False, None, {}, ["hello"]:
+            with self.assertRaises(ValidationError):
+                self.config["attribute_requirements"] = bad_value
+                OIDCProviderModel.parse_obj(self.config)
+
+        # An empty list of requirements is okay, ...
+        self.config["attribute_requirements"] = []
+        OIDCProviderModel.parse_obj(self.config)
+
+        # ...as is an omitted list of requirements...
+        del self.config["attribute_requirements"]
+        OIDCProviderModel.parse_obj(self.config)
+
+        # ...but not an explicit None.
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = None
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Multiple requirements are fine.
+        self.config["attribute_requirements"] = [{"attribute": "k", "value": "v"}] * 3
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertEqual(
+            len(model.attribute_requirements), 3, model.attribute_requirements
+        )
+
+        # The submodel's field types should be enforced too.
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{"attribute": "key", "value": 123}]
+            OIDCProviderModel.parse_obj(self.config)
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{"attribute": 123, "value": "val"}]
+            OIDCProviderModel.parse_obj(self.config)
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{"attribute": "a", "value": ["b"]}]
+            OIDCProviderModel.parse_obj(self.config)
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{"attribute": "a", "value": None}]
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Missing fields in the submodel are an error.
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{"attribute": "a"}]
+            OIDCProviderModel.parse_obj(self.config)
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{"value": "v"}]
+            OIDCProviderModel.parse_obj(self.config)
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [{}]
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Extra fields in the submodel are an error.
+        with self.assertRaises(ValidationError):
+            self.config["attribute_requirements"] = [
+                {"attribute": "a", "value": "v", "answer": "forty-two"}
+            ]
+            OIDCProviderModel.parse_obj(self.config)