summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-08-15 20:05:57 +0100
committerGitHub <noreply@github.com>2022-08-15 19:05:57 +0000
commitd642ce4b3258012da6c024b0b5d1396d2a3e69dd (patch)
treedaf0fc320f58f1364a2ee7e91f7eb0ddcffda58b /synapse
parentAdd a warning to retention documentation regarding the possibility of databas... (diff)
downloadsynapse-d642ce4b3258012da6c024b0b5d1396d2a3e69dd.tar.xz
Use Pydantic to systematically validate a first batch of endpoints in `synapse.rest.client.account`. (#13188)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/http/servlet.py25
-rw-r--r--synapse/rest/client/account.py148
-rw-r--r--synapse/rest/client/models.py69
-rw-r--r--synapse/rest/models.py23
4 files changed, 180 insertions, 85 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0e..26aaabfb34 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,9 +23,12 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
+    TypeVar,
     overload,
 )
 
+from pydantic import BaseModel, ValidationError
 from typing_extensions import Literal
 
 from twisted.web.server import Request
@@ -694,6 +697,28 @@ def parse_json_object_from_request(
     return content
 
 
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_json_object_from_request(
+    request: Request, model_type: Type[Model]
+) -> Model:
+    """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+    validate using the given pydantic model.
+
+    Raises:
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    content = parse_json_object_from_request(request, allow_empty_body=False)
+    try:
+        instance = model_type.parse_obj(content)
+    except ValidationError as e:
+        raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
+
+    return instance
+
+
 def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
     absent = []
     for k in required:
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 50edc6b7d3..e5ee63133b 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 import logging
 import random
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Optional, Tuple
 from urllib.parse import urlparse
 
+from pydantic import StrictBool, StrictStr, constr
+
 from twisted.web.server import Request
 
 from synapse.api.constants import LoginType
@@ -34,12 +35,15 @@ from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
+    parse_and_validate_json_object_from_request,
     parse_json_object_from_request,
     parse_string,
 )
 from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody
+from synapse.rest.models import RequestBodyModel
 from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -82,32 +86,16 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 400, "Email-based password resets have been disabled on this server"
             )
 
-        body = parse_json_object_from_request(request)
-
-        assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-
-        # Extract params from body
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        # Canonicalise the email address. The addresses are all stored canonicalised
-        # in the database. This allows the user to reset his password without having to
-        # know the exact spelling (eg. upper and lower case) of address in the database.
-        # Stored in the database "foo@bar.com"
-        # User requests with "FOO@bar.com" would raise a Not Found error
-        try:
-            email = validate_email(body["email"])
-        except ValueError as e:
-            raise SynapseError(400, str(e))
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
+        body = parse_and_validate_json_object_from_request(
+            request, EmailRequestTokenBody
+        )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
         await self.identity_handler.ratelimit_request_token_requests(
-            request, "email", email
+            request, "email", body.email
         )
 
         # The email will be sent to the stored address.
@@ -115,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         # an email address which is controlled by the attacker but which, after
         # canonicalisation, matches the one in our database.
         existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
-            "email", email
+            "email", body.email
         )
 
         if existing_user_id is None:
@@ -135,26 +123,26 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Have the configured identity server handle the request
             ret = await self.identity_handler.request_email_token(
                 self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
+                body.email,
+                body.client_secret,
+                body.send_attempt,
+                body.next_link,
             )
         else:
             # Send password reset emails from Synapse
             sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
+                body.email,
+                body.client_secret,
+                body.send_attempt,
                 self.mailer.send_password_reset_mail,
-                next_link,
+                body.next_link,
             )
 
             # Wrap the session id in a JSON object
             ret = {"sid": sid}
 
         threepid_send_requests.labels(type="email", reason="password_reset").observe(
-            send_attempt
+            body.send_attempt
         )
 
         return 200, ret
@@ -172,16 +160,23 @@ class PasswordRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self._set_password_handler = hs.get_set_password_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        logout_devices: StrictBool = True
+        if TYPE_CHECKING:
+            # workaround for https://github.com/samuelcolvin/pydantic/issues/156
+            new_password: Optional[StrictStr] = None
+        else:
+            new_password: Optional[constr(max_length=512, strict=True)] = None
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the new password provided to us.
-        new_password = body.pop("new_password", None)
+        new_password = body.new_password
         if new_password is not None:
-            if not isinstance(new_password, str) or len(new_password) > 512:
-                raise SynapseError(400, "Invalid password")
             self.password_policy_handler.validate_password(new_password)
 
         # there are two possibilities here. Either the user does not have an
@@ -201,7 +196,7 @@ class PasswordRestServlet(RestServlet):
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
                     requester,
                     request,
-                    body,
+                    body.dict(),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -224,7 +219,7 @@ class PasswordRestServlet(RestServlet):
                 result, params, session_id = await self.auth_handler.check_ui_auth(
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
-                    body,
+                    body.dict(),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -299,37 +294,33 @@ class DeactivateAccountRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    class PostBody(RequestBodyModel):
+        auth: Optional[AuthenticationData] = None
+        id_server: Optional[StrictStr] = None
+        # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297
+        erase: StrictBool = False
+
     @interactive_auth_handler
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        body = parse_json_object_from_request(request)
-        erase = body.get("erase", False)
-        if not isinstance(erase, bool):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'erase' must be a boolean, if given",
-                Codes.BAD_JSON,
-            )
+        body = parse_and_validate_json_object_from_request(request, self.PostBody)
 
         requester = await self.auth.get_user_by_req(request)
 
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase, requester
+                requester.user.to_string(), body.erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
             requester,
             request,
-            body,
+            body.dict(),
             "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(),
-            erase,
-            requester,
-            id_server=body.get("id_server"),
+            requester.user.to_string(), body.erase, requester, id_server=body.id_server
         )
         if result:
             id_server_unbind_result = "success"
@@ -364,28 +355,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                     "Adding emails have been disabled due to lack of an email config"
                 )
             raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
+                400,
+                "Adding an email to your account is disabled on this server",
             )
 
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        # Canonicalise the email address. The addresses are all stored canonicalised
-        # in the database.
-        # This ensures that the validation email is sent to the canonicalised address
-        # as it will later be entered into the database.
-        # Otherwise the email will be sent to "FOO@bar.com" and stored as
-        # "foo@bar.com" in database.
-        try:
-            email = validate_email(body["email"])
-        except ValueError as e:
-            raise SynapseError(400, str(e))
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
+        body = parse_and_validate_json_object_from_request(
+            request, EmailRequestTokenBody
+        )
 
-        if not await check_3pid_allowed(self.hs, "email", email):
+        if not await check_3pid_allowed(self.hs, "email", body.email):
             raise SynapseError(
                 403,
                 "Your email domain is not authorized on this server",
@@ -393,14 +371,14 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             )
 
         await self.identity_handler.ratelimit_request_token_requests(
-            request, "email", email
+            request, "email", body.email
         )
 
-        if next_link:
+        if body.next_link:
             # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
+            assert_valid_next_link(self.hs, body.next_link)
 
-        existing_user_id = await self.store.get_user_id_by_threepid("email", email)
+        existing_user_id = await self.store.get_user_id_by_threepid("email", body.email)
 
         if existing_user_id is not None:
             if self.config.server.request_token_inhibit_3pid_errors:
@@ -419,26 +397,26 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             # Have the configured identity server handle the request
             ret = await self.identity_handler.request_email_token(
                 self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
+                body.email,
+                body.client_secret,
+                body.send_attempt,
+                body.next_link,
             )
         else:
             # Send threepid validation emails from Synapse
             sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
+                body.email,
+                body.client_secret,
+                body.send_attempt,
                 self.mailer.send_add_threepid_mail,
-                next_link,
+                body.next_link,
             )
 
             # Wrap the session id in a JSON object
             ret = {"sid": sid}
 
         threepid_send_requests.labels(type="email", reason="add_threepid").observe(
-            send_attempt
+            body.send_attempt
         )
 
         return 200, ret
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
new file mode 100644
index 0000000000..3150602997
--- /dev/null
+++ b/synapse/rest/client/models.py
@@ -0,0 +1,69 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Dict, Optional
+
+from pydantic import Extra, StrictInt, StrictStr, constr, validator
+
+from synapse.rest.models import RequestBodyModel
+from synapse.util.threepids import validate_email
+
+
+class AuthenticationData(RequestBodyModel):
+    """
+    Data used during user-interactive authentication.
+
+    (The name "Authentication Data" is taken directly from the spec.)
+
+    Additional keys will be present, depending on the `type` field. Use `.dict()` to
+    access them.
+    """
+
+    class Config:
+        extra = Extra.allow
+
+    session: Optional[StrictStr] = None
+    type: Optional[StrictStr] = None
+
+
+class EmailRequestTokenBody(RequestBodyModel):
+    if TYPE_CHECKING:
+        client_secret: StrictStr
+    else:
+        # See also assert_valid_client_secret()
+        client_secret: constr(
+            regex="[0-9a-zA-Z.=_-]",  # noqa: F722
+            min_length=0,
+            max_length=255,
+            strict=True,
+        )
+    email: StrictStr
+    id_server: Optional[StrictStr]
+    id_access_token: Optional[StrictStr]
+    next_link: Optional[StrictStr]
+    send_attempt: StrictInt
+
+    @validator("id_access_token", always=True)
+    def token_required_for_identity_server(
+        cls, token: Optional[str], values: Dict[str, object]
+    ) -> Optional[str]:
+        if values.get("id_server") is not None and token is None:
+            raise ValueError("id_access_token is required if an id_server is supplied.")
+        return token
+
+    # Canonicalise the email address. The addresses are all stored canonicalised
+    # in the database. This allows the user to reset his password without having to
+    # know the exact spelling (eg. upper and lower case) of address in the database.
+    # Without this, an email stored in the database as "foo@bar.com" would cause
+    # user requests for "FOO@bar.com" to raise a Not Found error.
+    _email_validator = validator("email", allow_reuse=True)(validate_email)
diff --git a/synapse/rest/models.py b/synapse/rest/models.py
new file mode 100644
index 0000000000..ac39cda8e5
--- /dev/null
+++ b/synapse/rest/models.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Extra
+
+
+class RequestBodyModel(BaseModel):
+    """A custom version of Pydantic's BaseModel which
+
+     - ignores unknown fields and
+     - does not allow fields to be overwritten after construction,
+
+    but otherwise uses Pydantic's default behaviour.
+
+    Ignoring unknown fields is a useful default. It means that clients can provide
+    unstable field not known to the server without the request being refused outright.
+
+    Subclassing in this way is recommended by
+    https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+    """
+
+    class Config:
+        # By default, ignore fields that we don't recognise.
+        extra = Extra.ignore
+        # By default, don't allow fields to be reassigned after parsing.
+        allow_mutation = False