summary refs log tree commit diff
path: root/synapse/http/servlet.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/servlet.py')
-rw-r--r--synapse/http/servlet.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fd90ba7828..622c47131e 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -17,6 +17,9 @@
 
 import logging
 
+import jsonschema
+from jsonschema.exceptions import best_match
+
 from synapse.api.errors import Codes, SynapseError
 from synapse.util import json_decoder
 
@@ -222,7 +225,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     return content
 
 
-def parse_json_object_from_request(request, allow_empty_body=False):
+def parse_json_object_from_request(request, validator=None, allow_empty_body=False):
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
@@ -237,9 +240,14 @@ def parse_json_object_from_request(request, allow_empty_body=False):
     content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body)
 
     if allow_empty_body and content is None:
-        return {}
+        content = {}
+
+    if validator:
+        error = best_match(validator.iter_errors(content))
+        if error:
+            raise SynapseError(400, error.message, errcode=Codes.BAD_JSON)
 
-    if type(content) != dict:
+    elif type(content) != dict:
         message = "Content must be a JSON object."
         raise SynapseError(400, message, errcode=Codes.BAD_JSON)
 
@@ -291,5 +299,13 @@ class RestServlet:
                         method, patterns, method_handler, servlet_classname
                     )
 
+                if hasattr(self, "%s_SCHEMA" % (method,)):
+                    schema = getattr(self, "%s_SCHEMA" % (method,))
+                    setattr(
+                        self,
+                        "%s_VALIDATOR" % (method,),
+                        jsonschema.Draft7Validator(schema),
+                    )
+
         else:
             raise NotImplementedError("RestServlet must register something.")