diff options
-rw-r--r-- | synapse/http/servlet.py | 22 | ||||
-rw-r--r-- | synapse/rest/client/v1/profile.py | 16 |
2 files changed, 29 insertions, 9 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fd90ba7828..622c47131e 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -17,6 +17,9 @@ import logging +import jsonschema +from jsonschema.exceptions import best_match + from synapse.api.errors import Codes, SynapseError from synapse.util import json_decoder @@ -222,7 +225,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): return content -def parse_json_object_from_request(request, allow_empty_body=False): +def parse_json_object_from_request(request, validator=None, allow_empty_body=False): """Parse a JSON object from the body of a twisted HTTP request. Args: @@ -237,9 +240,14 @@ def parse_json_object_from_request(request, allow_empty_body=False): content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body) if allow_empty_body and content is None: - return {} + content = {} + + if validator: + error = best_match(validator.iter_errors(content)) + if error: + raise SynapseError(400, error.message, errcode=Codes.BAD_JSON) - if type(content) != dict: + elif type(content) != dict: message = "Content must be a JSON object." raise SynapseError(400, message, errcode=Codes.BAD_JSON) @@ -291,5 +299,13 @@ class RestServlet: method, patterns, method_handler, servlet_classname ) + if hasattr(self, "%s_SCHEMA" % (method,)): + schema = getattr(self, "%s_SCHEMA" % (method,)) + setattr( + self, + "%s_VALIDATOR" % (method,), + jsonschema.Draft7Validator(schema), + ) + else: raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e7fe50ed72..78a3bebc7f 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -24,6 +24,14 @@ from synapse.types import UserID class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) + PUT_SCHEMA = { + "type": "object", + "properties": { + "displayname": {"oneOf": [{"type": "string"}, {"type": "null"}]}, + "required": ["displayname"], + }, + } + def __init__(self, hs): super(ProfileDisplaynameRestServlet, self).__init__() self.hs = hs @@ -54,12 +62,8 @@ class ProfileDisplaynameRestServlet(RestServlet): user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) - content = parse_json_object_from_request(request) - - try: - new_name = content["displayname"] - except Exception: - return 400, "Unable to parse name" + content = parse_json_object_from_request(request, validator=self.PUT_VALIDATOR) + new_name = content["displayname"] await self.profile_handler.set_displayname(user, requester, new_name, is_admin) |