summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-20 00:56:02 +0100
committerDavid Robertson <davidr@element.io>2022-05-20 00:56:02 +0100
commit2179c6376aa9fbdf650518db35de17e1fe8b6377 (patch)
treeed0b05152b111f01ede04d9d7d324296bd816cd3
parentmove TYPE_CHECKING workaround outside (diff)
downloadsynapse-2179c6376aa9fbdf650518db35de17e1fe8b6377.tar.xz
Extra fields and tests
Pleasantly: no pain here
-rw-r--r--synapse/config/oidc2.py40
-rw-r--r--tests/config/test_oidc2.py159
2 files changed, 180 insertions, 19 deletions
diff --git a/synapse/config/oidc2.py b/synapse/config/oidc2.py
index 12059f5bef..5224a255bd 100644
--- a/synapse/config/oidc2.py
+++ b/synapse/config/oidc2.py
@@ -1,3 +1,5 @@
+from enum import Enum
+
 from typing import TYPE_CHECKING, Any, Optional, Tuple
 
 from pydantic import BaseModel, StrictBool, StrictStr, constr
@@ -7,6 +9,7 @@ from pydantic import BaseModel, StrictBool, StrictStr, constr
 
 if TYPE_CHECKING:
     IDP_ID_TYPE = str
+    IDP_BRAND_TYPE = str
 else:
     IDP_ID_TYPE = constr(
         strict=True,
@@ -14,15 +17,39 @@ else:
         max_length=250,
         regex="^[A-Za-z0-9._~-]+$",  # noqa: F722
     )
+    IDP_BRAND_TYPE = constr(
+        strict=True,
+        min_length=1,
+        max_length=255,
+        regex="^[a-z][a-z0-9_.-]*$",  # noqa: F722
+    )
+
+# the following list of enum members is the same as the keys of
+# authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
+# to avoid importing authlib here.
+class ClientAuthMethods(str, Enum):
+    # The duplication is unfortunate. 3.11 should have StrEnum though,
+    # and there is a backport available for 3.8.6.
+    client_secret_basic = "client_secret_basic"
+    client_secret_post = "client_secret_post"
+    none = "none"
+
+
+class UserProfileMethod(str, Enum):
+    # The duplication is unfortunate. 3.11 should have StrEnum though,
+    # and there is a backport available for 3.8.6.
+    auto = "auto"
+    userinfo_endpoint = "userinfo_endpoint"
 
 
 class OIDCProviderModel(BaseModel):
     """
     Notes on Pydantic:
-    - I've used StrictStr because a plain `str` accepts integers and calls str() on them
-    - I've factored out the validators here to demonstrate that we can avoid some duplication
-      if there are common patterns. Otherwise one could use @validator("field_name") and
-      define the validator function inline.
+    - I've used StrictStr because a plain `str` e.g. accepts integers and calls str()
+      on them
+    - pulling out constr() into IDP_ID_TYPE is a little awkward, but necessary to keep
+      mypy happy
+    -
     """
 
     # a unique identifier for this identity provider. Used in the 'user_external_ids'
@@ -63,7 +90,7 @@ class OIDCProviderModel(BaseModel):
     # auth method to use when exchanging the token.
     # Valid values are 'client_secret_basic', 'client_secret_post' and
     # 'none'.
-    client_auth_method: StrictStr = "client_secret_basic"
+    client_auth_method: ClientAuthMethods = ClientAuthMethods.client_secret_basic
 
     # list of scopes to request
     scopes: Tuple[StrictStr, ...] = ("openid",)
@@ -91,8 +118,7 @@ class OIDCProviderModel(BaseModel):
 
     # Whether to fetch the user profile from the userinfo endpoint. Valid
     # values are: "auto" or "userinfo_endpoint".
-    # TODO enum
-    user_profile_method: StrictStr = "auto"
+    user_profile_method: UserProfileMethod = UserProfileMethod.auto
 
     # whether to allow a user logging in via OIDC to match a pre-existing account
     # instead of failing
diff --git a/tests/config/test_oidc2.py b/tests/config/test_oidc2.py
index 102b35f9cc..e340a7d43b 100644
--- a/tests/config/test_oidc2.py
+++ b/tests/config/test_oidc2.py
@@ -4,7 +4,7 @@ from typing import Any, Dict
 import yaml
 from pydantic import ValidationError
 
-from synapse.config.oidc2 import OIDCProviderModel
+from synapse.config.oidc2 import OIDCProviderModel, ClientAuthMethods
 
 from tests.unittest import TestCase
 
@@ -36,29 +36,31 @@ user_mapping_provider:
 
 
 class PydanticOIDCTestCase(TestCase):
+    """Examples to build confidence that pydantic is doing the validation we think
+    it's doing"""
+
     # Each test gets a dummy config it can change as it sees fit
     config: Dict[str, Any]
 
     def setUp(self) -> None:
         self.config = deepcopy(SAMPLE_CONFIG)
 
-    def test_idp_id(self) -> None:
-        """Demonstrate that Pydantic validates idp_id correctly."""
+    def test_example_config(self):
+        # Check that parsing the sample config doesn't raise an error.
         OIDCProviderModel.parse_obj(self.config)
 
+    def test_idp_id(self) -> None:
+        """Example of using a Pydantic constr() field without a default."""
         # Enforce that idp_id is required.
         with self.assertRaises(ValidationError):
             del self.config["idp_id"]
             OIDCProviderModel.parse_obj(self.config)
 
         # Enforce that idp_id is a string.
-        with self.assertRaises(ValidationError) as e:
-            self.config["idp_id"] = 123
-            OIDCProviderModel.parse_obj(self.config)
-        print(e.exception)
-        with self.assertRaises(ValidationError):
-            self.config["idp_id"] = None
-            OIDCProviderModel.parse_obj(self.config)
+        for bad_vlaue in 123, None, ["a"], {"a": "b"}:
+            with self.assertRaises(ValidationError) as e:
+                self.config["idp_id"] = bad_vlaue
+                OIDCProviderModel.parse_obj(self.config)
 
         # Enforce a length between 1 and 250.
         with self.assertRaises(ValidationError):
@@ -68,7 +70,7 @@ class PydanticOIDCTestCase(TestCase):
             self.config["idp_id"] = "a" * 251
             OIDCProviderModel.parse_obj(self.config)
 
-        # Enforce the character set
+        # Enforce the regex
         with self.assertRaises(ValidationError):
             self.config["idp_id"] = "$"
             OIDCProviderModel.parse_obj(self.config)
@@ -77,4 +79,137 @@ class PydanticOIDCTestCase(TestCase):
         with self.assertRaises(ValidationError) as e:
             self.config["idp_id"] = "$" * 500
             OIDCProviderModel.parse_obj(self.config)
-        print(e.exception)
+
+    def test_issuer(self) -> None:
+        """Example of a StrictStr field without a default."""
+
+        # Empty and nonempty strings should be accepted.
+        for good_value in "", "hello", "hello" * 1000, "☃":
+            self.config["issuer"] = good_value
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Invalid types should be rejected.
+        for bad_value in 123, None, ["h", "e", "l", "l", "o"], {"hello": "there"}:
+            with self.assertRaises(ValidationError):
+                self.config["issuer"] = bad_value
+                OIDCProviderModel.parse_obj(self.config)
+
+        # A missing issuer should be rejected.
+        with self.assertRaises(ValidationError):
+            del self.config["issuer"]
+            OIDCProviderModel.parse_obj(self.config)
+
+    def test_idp_brand(self) -> None:
+        """Example of an Optional[StrictStr] field."""
+        # Empty and nonempty strings should be accepted.
+        for good_value in "", "hello", "hello" * 1000, "☃":
+            self.config["idp_brand"] = good_value
+            OIDCProviderModel.parse_obj(self.config)
+
+        # Invalid types should be rejected.
+        for bad_value in 123, ["h", "e", "l", "l", "o"], {"hello": "there"}:
+            with self.assertRaises(ValidationError):
+                self.config["idp_brand"] = bad_value
+                OIDCProviderModel.parse_obj(self.config)
+
+        # A lack of an idp_brand is fine...
+        del self.config["idp_brand"]
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertIsNone(model.idp_brand)
+
+        # ... and interpreted the same as an explicit `None`.
+        self.config["idp_brand"] = None
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertIsNone(model.idp_brand)
+
+    def test_discover(self) -> None:
+        """Example of a StrictBool field with a default."""
+        # Booleans are permitted.
+        for value in True, False:
+            self.config["discover"] = value
+            model = OIDCProviderModel.parse_obj(self.config)
+            self.assertEqual(model.discover, value)
+
+        # Invalid types should be rejected.
+        for bad_value in (
+            -1.0,
+            0,
+            1,
+            float("nan"),
+            "yes",
+            "NO",
+            "True",
+            "true",
+            None,
+            "None",
+            "null",
+            ["a"],
+            {"a": "b"},
+        ):
+            self.config["discover"] = bad_value
+            with self.assertRaises(ValidationError):
+                OIDCProviderModel.parse_obj(self.config)
+
+        # A missing value is okay, because this field has a default.
+        del self.config["discover"]
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertIs(model.discover, True)
+
+    def test_client_auth_method(self) -> None:
+        """This is an example of using a Pydantic string enum field."""
+        # check the allowed values are permitted and deserialise to an enum member
+        for method in "client_secret_basic", "client_secret_post", "none":
+            self.config["client_auth_method"] = method
+            model = OIDCProviderModel.parse_obj(self.config)
+            self.assertIs(model.client_auth_method, ClientAuthMethods[method])
+
+        # check the default applies if no auth method is provided.
+        del self.config["client_auth_method"]
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertIs(model.client_auth_method, ClientAuthMethods.client_secret_basic)
+
+        # Check invalid types are rejected
+        for bad_value in 123, ["client_secret_basic"], {"a": 1}, None:
+            with self.assertRaises(ValidationError):
+                self.config["client_auth_method"] = bad_value
+                OIDCProviderModel.parse_obj(self.config)
+
+        # Check that disallowed strings are rejected
+        with self.assertRaises(ValidationError):
+            self.config["client_auth_method"] = "No, Luke, _I_ am your father!"
+            OIDCProviderModel.parse_obj(self.config)
+
+    def test_scopes(self) -> None:
+        """Example of a Tuple[StrictStr] with a default."""
+        # Check that the parsed object holds a tuple
+        self.config["scopes"] = []
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertEqual(model.scopes, ())
+
+        # Check a variety of list lengths are accepted.
+        for good_value in ["aa"], ["hello", "world"], ["a"] * 4, [""] * 20:
+            self.config["scopes"] = good_value
+            model = OIDCProviderModel.parse_obj(self.config)
+            self.assertEqual(model.scopes, tuple(good_value))
+
+        # Check invalid types are rejected.
+        for bad_value in (
+            "",
+            "abc",
+            123,
+            {},
+            {"a": 1},
+            None,
+            [None],
+            [["a"]],
+            [{}],
+            [456],
+        ):
+            with self.assertRaises(ValidationError):
+                self.config["scopes"] = bad_value
+                OIDCProviderModel.parse_obj(self.config)
+
+        # Check that "scopes" may be omitted.
+        del self.config["scopes"]
+        model = OIDCProviderModel.parse_obj(self.config)
+        self.assertEqual(model.scopes, ("openid",))