diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fd90ba7828..622c47131e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -17,6 +17,9 @@
import logging
+import jsonschema
+from jsonschema.exceptions import best_match
+
from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
@@ -222,7 +225,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return content
-def parse_json_object_from_request(request, allow_empty_body=False):
+def parse_json_object_from_request(request, validator=None, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request.
Args:
@@ -237,9 +240,14 @@ def parse_json_object_from_request(request, allow_empty_body=False):
content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body)
if allow_empty_body and content is None:
- return {}
+ content = {}
+
+ if validator:
+ error = best_match(validator.iter_errors(content))
+ if error:
+ raise SynapseError(400, error.message, errcode=Codes.BAD_JSON)
- if type(content) != dict:
+ elif type(content) != dict:
message = "Content must be a JSON object."
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
@@ -291,5 +299,13 @@ class RestServlet:
method, patterns, method_handler, servlet_classname
)
+ if hasattr(self, "%s_SCHEMA" % (method,)):
+ schema = getattr(self, "%s_SCHEMA" % (method,))
+ setattr(
+ self,
+ "%s_VALIDATOR" % (method,),
+ jsonschema.Draft7Validator(schema),
+ )
+
else:
raise NotImplementedError("RestServlet must register something.")
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..78a3bebc7f 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -24,6 +24,14 @@ from synapse.types import UserID
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
+ PUT_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "displayname": {"oneOf": [{"type": "string"}, {"type": "null"}]},
+ "required": ["displayname"],
+ },
+ }
+
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
@@ -54,12 +62,8 @@ class ProfileDisplaynameRestServlet(RestServlet):
user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user)
- content = parse_json_object_from_request(request)
-
- try:
- new_name = content["displayname"]
- except Exception:
- return 400, "Unable to parse name"
+ content = parse_json_object_from_request(request, validator=self.PUT_VALIDATOR)
+ new_name = content["displayname"]
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
|