summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-09-07 12:16:10 +0100
committerGitHub <noreply@github.com>2022-09-07 12:16:10 +0100
commitb58386e37e30e920332e4b04011b528a66a39fad (patch)
treee1fc845196b649e8819aee5fe002a4331c4f6c2f /synapse
parentCancel the processing of key query requests when they time out. (#13680) (diff)
downloadsynapse-b58386e37e30e920332e4b04011b528a66a39fad.tar.xz
A second batch of Pydantic models for rest/client/account.py (#13687)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/servlet.py19
-rw-r--r--synapse/rest/client/account.py54
-rw-r--r--synapse/rest/client/models.py24
3 files changed, 63 insertions, 34 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 26aaabfb34..80acbdcf3c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -28,7 +28,8 @@ from typing import (
     overload,
 )
 
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
+from pydantic.error_wrappers import ErrorWrapper
 from typing_extensions import Literal
 
 from twisted.web.server import Request
@@ -714,7 +715,21 @@ def parse_and_validate_json_object_from_request(
     try:
         instance = model_type.parse_obj(content)
     except ValidationError as e:
-        raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
+        # Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a
+        # more specific error if possible (which occasionally helps us to be spec-
+        # compliant) This is a bit awkward because the spec's error codes aren't very
+        # clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM.
+        errcode = Codes.BAD_JSON
+
+        raw_errors = e.raw_errors
+        if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper):
+            raw_error = raw_errors[0].exc
+            if isinstance(raw_error, MissingError):
+                errcode = Codes.MISSING_PARAM
+            elif isinstance(raw_error, PydanticValueError):
+                errcode = Codes.INVALID_PARAM
+
+        raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode)
 
     return instance
 
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 1f9a8ccc23..a09aaf3448 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 from urllib.parse import urlparse
 
 from pydantic import StrictBool, StrictStr, constr
@@ -41,7 +41,11 @@ from synapse.http.servlet import (
 from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
-from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody
+from synapse.rest.client.models import (
+    AuthenticationData,
+    EmailRequestTokenBody,
+    MsisdnRequestTokenBody,
+)
 from synapse.rest.models import RequestBodyModel
 from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
@@ -400,23 +404,16 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(
-            body, ["client_secret", "country", "phone_number", "send_attempt"]
+        body = parse_and_validate_json_object_from_request(
+            request, MsisdnRequestTokenBody
         )
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        country = body["country"]
-        phone_number = body["phone_number"]
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
-
-        msisdn = phone_number_to_msisdn(country, phone_number)
+        msisdn = phone_number_to_msisdn(body.country, body.phone_number)
 
         if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
                 403,
+                # TODO: is this error message accurate? Looks like we've only rejected
+                #       this phone number, not necessarily all phone numbers
                 "Account phone numbers are not authorized on this server",
                 Codes.THREEPID_DENIED,
             )
@@ -425,9 +422,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
             request, "msisdn", msisdn
         )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
         existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
 
@@ -454,15 +451,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
 
         ret = await self.identity_handler.requestMsisdnToken(
             self.hs.config.registration.account_threepid_delegate_msisdn,
-            country,
-            phone_number,
-            client_secret,
-            send_attempt,
-            next_link,
+            body.country,
+            body.phone_number,
+            body.client_secret,
+            body.send_attempt,
+            body.next_link,
         )
 
         threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
-            send_attempt
+            body.send_attempt
         )
 
         return 200, ret
@@ -845,17 +842,18 @@ class AccountStatusRestServlet(RestServlet):
         self._auth = hs.get_auth()
         self._account_handler = hs.get_account_handler()
 
+    class PostBody(RequestBodyModel):
+        # TODO: we could validate that each user id is an mxid here, and/or parse it
+        #       as a UserID
+        user_ids: List[StrictStr]
+
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await self._auth.get_user_by_req(request)
 
-        body = parse_json_object_from_request(request)
-        if "user_ids" not in body:
-            raise SynapseError(
-                400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
-            )
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         statuses, failures = await self._account_handler.get_account_statuses(
-            body["user_ids"],
+            body.user_ids,
             allow_remote=True,
         )
 
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
index 3150602997..6278450c70 100644
--- a/synapse/rest/client/models.py
+++ b/synapse/rest/client/models.py
@@ -25,8 +25,8 @@ class AuthenticationData(RequestBodyModel):
 
     (The name "Authentication Data" is taken directly from the spec.)
 
-    Additional keys will be present, depending on the `type` field. Use `.dict()` to
-    access them.
+    Additional keys will be present, depending on the `type` field. Use
+    `.dict(exclude_unset=True)` to access them.
     """
 
     class Config:
@@ -36,7 +36,7 @@ class AuthenticationData(RequestBodyModel):
     type: Optional[StrictStr] = None
 
 
-class EmailRequestTokenBody(RequestBodyModel):
+class ThreePidRequestTokenBody(RequestBodyModel):
     if TYPE_CHECKING:
         client_secret: StrictStr
     else:
@@ -47,7 +47,7 @@ class EmailRequestTokenBody(RequestBodyModel):
             max_length=255,
             strict=True,
         )
-    email: StrictStr
+
     id_server: Optional[StrictStr]
     id_access_token: Optional[StrictStr]
     next_link: Optional[StrictStr]
@@ -61,9 +61,25 @@ class EmailRequestTokenBody(RequestBodyModel):
             raise ValueError("id_access_token is required if an id_server is supplied.")
         return token
 
+
+class EmailRequestTokenBody(ThreePidRequestTokenBody):
+    email: StrictStr
+
     # Canonicalise the email address. The addresses are all stored canonicalised
     # in the database. This allows the user to reset his password without having to
     # know the exact spelling (eg. upper and lower case) of address in the database.
     # Without this, an email stored in the database as "foo@bar.com" would cause
     # user requests for "FOO@bar.com" to raise a Not Found error.
     _email_validator = validator("email", allow_reuse=True)(validate_email)
+
+
+if TYPE_CHECKING:
+    ISO3116_1_Alpha_2 = StrictStr
+else:
+    # Per spec: two-letter uppercase ISO-3166-1-alpha-2
+    ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
+
+
+class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
+    country: ISO3116_1_Alpha_2
+    phone_number: StrictStr