summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13736.feature1
-rw-r--r--synapse/rest/client/account.py65
-rw-r--r--synapse/rest/client/models.py28
-rw-r--r--tests/rest/client/test_models.py29
4 files changed, 78 insertions, 45 deletions
diff --git a/changelog.d/13736.feature b/changelog.d/13736.feature
new file mode 100644
index 0000000000..60a63c1009
--- /dev/null
+++ b/changelog.d/13736.feature
@@ -0,0 +1 @@
+Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index a09aaf3448..2db2a04f95 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 from urllib.parse import urlparse
 
 from pydantic import StrictBool, StrictStr, constr
+from typing_extensions import Literal
 
 from twisted.web.server import Request
 
@@ -43,6 +44,7 @@ from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
 from synapse.rest.client.models import (
     AuthenticationData,
+    ClientSecretStr,
     EmailRequestTokenBody,
     MsisdnRequestTokenBody,
 )
@@ -627,6 +629,11 @@ class ThreepidAddRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        client_secret: ClientSecretStr
+        sid: StrictStr
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.registration.enable_3pid_changes:
@@ -636,22 +643,17 @@ class ThreepidAddRestServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
-        body = parse_json_object_from_request(request)
-
-        assert_params_in_dict(body, ["client_secret", "sid"])
-        sid = body["sid"]
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         await self.auth_handler.validate_user_via_ui_auth(
             requester,
             request,
-            body,
+            body.dict(exclude_unset=True),
             "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
-            client_secret, sid
+            body.client_secret, body.sid
         )
         if validation_session:
             await self.auth_handler.add_threepid(
@@ -676,23 +678,20 @@ class ThreepidBindRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.auth = hs.get_auth()
 
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
+    class PostBody(RequestBodyModel):
+        client_secret: ClientSecretStr
+        id_access_token: StrictStr
+        id_server: StrictStr
+        sid: StrictStr
 
-        assert_params_in_dict(
-            body, ["id_server", "sid", "id_access_token", "client_secret"]
-        )
-        id_server = body["id_server"]
-        sid = body["sid"]
-        id_access_token = body["id_access_token"]
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
 
         await self.identity_handler.bind_threepid(
-            client_secret, sid, user_id, id_server, id_access_token
+            body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
         )
 
         return 200, {}
@@ -708,23 +707,27 @@ class ThreepidUnbindRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.datastore = self.hs.get_datastores().main
 
+    class PostBody(RequestBodyModel):
+        address: StrictStr
+        id_server: Optional[StrictStr] = None
+        medium: Literal["email", "msisdn"]
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """Unbind the given 3pid from a specific identity server, or identity servers that are
         known to have this 3pid bound
         """
         requester = await self.auth.get_user_by_req(request)
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["medium", "address"])
-
-        medium = body.get("medium")
-        address = body.get("address")
-        id_server = body.get("id_server")
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         # Attempt to unbind the threepid from an identity server. If id_server is None, try to
         # unbind from all identity servers this threepid has been added to in the past
         result = await self.identity_handler.try_unbind_threepid(
             requester.user.to_string(),
-            {"address": address, "medium": medium, "id_server": id_server},
+            {
+                "address": body.address,
+                "medium": body.medium,
+                "id_server": body.id_server,
+            },
         )
         return 200, {"id_server_unbind_result": "success" if result else "no-support"}
 
@@ -738,21 +741,25 @@ class ThreepidDeleteRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
+    class PostBody(RequestBodyModel):
+        address: StrictStr
+        id_server: Optional[StrictStr] = None
+        medium: Literal["email", "msisdn"]
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if not self.hs.config.registration.enable_3pid_changes:
             raise SynapseError(
                 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
             )
 
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["medium", "address"])
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
 
         try:
             ret = await self.auth_handler.delete_threepid(
-                user_id, body["medium"], body["address"], body.get("id_server")
+                user_id, body.medium, body.address, body.id_server
             )
         except Exception:
             # NB. This endpoint should succeed if there is nothing to
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
index 6278450c70..3d7940b0fc 100644
--- a/synapse/rest/client/models.py
+++ b/synapse/rest/client/models.py
@@ -36,18 +36,20 @@ class AuthenticationData(RequestBodyModel):
     type: Optional[StrictStr] = None
 
 
-class ThreePidRequestTokenBody(RequestBodyModel):
-    if TYPE_CHECKING:
-        client_secret: StrictStr
-    else:
-        # See also assert_valid_client_secret()
-        client_secret: constr(
-            regex="[0-9a-zA-Z.=_-]",  # noqa: F722
-            min_length=0,
-            max_length=255,
-            strict=True,
-        )
+if TYPE_CHECKING:
+    ClientSecretStr = StrictStr
+else:
+    # See also assert_valid_client_secret()
+    ClientSecretStr = constr(
+        regex="[0-9a-zA-Z.=_-]",  # noqa: F722
+        min_length=1,
+        max_length=255,
+        strict=True,
+    )
+
 
+class ThreepidRequestTokenBody(RequestBodyModel):
+    client_secret: ClientSecretStr
     id_server: Optional[StrictStr]
     id_access_token: Optional[StrictStr]
     next_link: Optional[StrictStr]
@@ -62,7 +64,7 @@ class ThreePidRequestTokenBody(RequestBodyModel):
         return token
 
 
-class EmailRequestTokenBody(ThreePidRequestTokenBody):
+class EmailRequestTokenBody(ThreepidRequestTokenBody):
     email: StrictStr
 
     # Canonicalise the email address. The addresses are all stored canonicalised
@@ -80,6 +82,6 @@ else:
     ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
 
 
-class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
+class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
     country: ISO3116_1_Alpha_2
     phone_number: StrictStr
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
index a9da00665e..0b8fcb0c47 100644
--- a/tests/rest/client/test_models.py
+++ b/tests/rest/client/test_models.py
@@ -11,14 +11,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest
+import unittest as stdlib_unittest
 
-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
+from typing_extensions import Literal
 
 from synapse.rest.client.models import EmailRequestTokenBody
 
 
-class EmailRequestTokenBodyTestCase(unittest.TestCase):
+class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
+    class Model(BaseModel):
+        medium: Literal["email", "msisdn"]
+
+    def test_accepts_valid_medium_string(self) -> None:
+        """Sanity check that Pydantic behaves sensibly with an enum-of-str
+
+        This is arguably more of a test of a class that inherits from str and Enum
+        simultaneously.
+        """
+        model = self.Model.parse_obj({"medium": "email"})
+        self.assertEqual(model.medium, "email")
+
+    def test_rejects_invalid_medium_value(self) -> None:
+        with self.assertRaises(ValidationError):
+            self.Model.parse_obj({"medium": "interpretive_dance"})
+
+    def test_rejects_invalid_medium_type(self) -> None:
+        with self.assertRaises(ValidationError):
+            self.Model.parse_obj({"medium": 123})
+
+
+class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
     base_request = {
         "client_secret": "hunter2",
         "email": "alice@wonderland.com",