summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/http/servlet.py70
-rw-r--r--synapse/rest/client/v1/directory.py16
-rw-r--r--synapse/rest/client/v1/login.py15
-rw-r--r--synapse/rest/client/v1/push_rule.py16
-rw-r--r--synapse/rest/client/v1/pusher.py17
-rw-r--r--synapse/rest/client/v1/register.py14
-rw-r--r--synapse/rest/client/v1/room.py26
-rw-r--r--synapse/rest/client/v2_alpha/_base.py22
-rw-r--r--synapse/rest/client/v2_alpha/account.py8
-rw-r--r--synapse/rest/client/v2_alpha/register.py8
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py6
11 files changed, 97 insertions, 121 deletions
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 7bd87940b4..41c519c000 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -15,14 +15,27 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes
 
 import logging
+import simplejson
 
 logger = logging.getLogger(__name__)
 
 
 def parse_integer(request, name, default=None, required=False):
+    """Parse an integer parameter from the request string
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :return: An int value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not an integer.
+    """
     if name in request.args:
         try:
             return int(request.args[name][0])
@@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False):
     else:
         if required:
             message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
 def parse_boolean(request, name, default=None, required=False):
+    """Parse a boolean parameter from the request query string
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :return: A bool value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not one of "true" or "false".
+    """
+
     if name in request.args:
         try:
             return {
@@ -53,30 +79,64 @@ def parse_boolean(request, name, default=None, required=False):
     else:
         if required:
             message = "Missing boolean query parameter %r" % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
 def parse_string(request, name, default=None, required=False,
                  allowed_values=None, param_type="string"):
+    """Parse a string parameter from the request query string.
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :param allowed_values (list): List of allowed values for the string,
+        or None if any value is allowed, defaults to None
+    :return: A string value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
     if name in request.args:
         value = request.args[name][0]
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
             )
-            raise SynapseError(message)
+            raise SynapseError(400, message)
         else:
             return value
     else:
         if required:
             message = "Missing %s query parameter %r" % (param_type, name)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
+def parse_json_object_from_request(request):
+    """Parse a JSON object from the body of a twisted HTTP request.
+
+    :param request: the twisted HTTP request.
+    :raises
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    try:
+        content = simplejson.loads(request.content.read())
+        if type(content) != dict:
+            message = "Content must be a JSON object."
+            raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+        return content
+    except simplejson.JSONDecodeError:
+        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+
 class RestServlet(object):
 
     """ A Synapse REST Servlet.
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 8bfe9fdea8..60c5ec77aa 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -18,9 +18,10 @@ from twisted.internet import defer
 
 from synapse.api.errors import AuthError, SynapseError, Codes
 from synapse.types import RoomAlias
+from synapse.http.servlet import parse_json_object_from_request
+
 from .base import ClientV1RestServlet, client_path_patterns
 
-import simplejson as json
 import logging
 
 
@@ -45,7 +46,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_alias):
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
         if "room_id" not in content:
             raise SynapseError(400, "Missing room_id key",
                                errcode=Codes.BAD_JSON)
@@ -135,14 +136,3 @@ class ClientDirectoryServer(ClientV1RestServlet):
         )
 
         defer.returnValue((200, {}))
-
-
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.",
-                               errcode=Codes.NOT_JSON)
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6902a60a8..fe593d07ce 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError, LoginError, Codes
 from synapse.types import UserID
 from synapse.http.server import finish_request
+from synapse.http.servlet import parse_json_object_from_request
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -79,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        login_submission = _parse_json(request)
+        login_submission = parse_json_object_from_request(request)
         try:
             if login_submission["type"] == LoginRestServlet.PASS_TYPE:
                 if not self.password_enabled:
@@ -400,18 +401,6 @@ class CasTicketServlet(ClientV1RestServlet):
         return (user, attributes)
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(
-                400, "Content must be a JSON object.", errcode=Codes.BAD_JSON
-            )
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.saml2_enabled:
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 981d7708db..b5695d427a 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import (
-    SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
+    SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
 )
 from .base import ClientV1RestServlet, client_path_patterns
 from synapse.storage.push_rule import (
@@ -25,8 +25,7 @@ from synapse.storage.push_rule import (
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.push.baserules import BASE_RULE_IDS
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP
-
-import simplejson as json
+from synapse.http.servlet import parse_json_object_from_request
 
 
 class PushRuleRestServlet(ClientV1RestServlet):
@@ -52,7 +51,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
         if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
             raise SynapseError(400, "rule_id may not contain slashes")
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         user_id = requester.user.to_string()
 
@@ -341,14 +340,5 @@ class InvalidRuleException(Exception):
     pass
 
 
-# XXX: C+ped from rest/room.py - surely this should be common?
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_servlets(hs, http_server):
     PushRuleRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 4c662e6e3c..ee029b4f77 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -17,9 +17,10 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, Codes
 from synapse.push import PusherConfigException
+from synapse.http.servlet import parse_json_object_from_request
+
 from .base import ClientV1RestServlet, client_path_patterns
 
-import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -33,7 +34,7 @@ class PusherRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         user = requester.user
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         pusher_pool = self.hs.get_pusherpool()
 
@@ -92,17 +93,5 @@ class PusherRestServlet(ClientV1RestServlet):
         return 200, {}
 
 
-# XXX: C+ped from rest/room.py - surely this should be common?
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.",
-                               errcode=Codes.NOT_JSON)
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_servlets(hs, http_server):
     PusherRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 040a7a7ffa..c6a2ef2ccc 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -20,12 +20,12 @@ from synapse.api.errors import SynapseError, Codes
 from synapse.api.constants import LoginType
 from .base import ClientV1RestServlet, client_path_patterns
 import synapse.util.stringutils as stringutils
+from synapse.http.servlet import parse_json_object_from_request
 
 from synapse.util.async import run_on_reactor
 
 from hashlib import sha1
 import hmac
-import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -98,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        register_json = _parse_json(request)
+        register_json = parse_json_object_from_request(request)
 
         session = (register_json["session"]
                    if "session" in register_json else None)
@@ -355,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
             )
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.")
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.")
-
-
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 4b7d198c52..7a9f3d11b9 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -22,6 +22,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.api.constants import EventTypes, Membership
 from synapse.types import UserID, RoomID, RoomAlias
 from synapse.events.utils import serialize_event
+from synapse.http.servlet import parse_json_object_from_request
 
 import simplejson as json
 import logging
@@ -137,7 +138,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
         requester = yield self.auth.get_user_by_req(request)
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         event_dict = {
             "type": event_type,
@@ -179,7 +180,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_type, txn_id=None):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         msg_handler = self.handlers.message_handler
         event = yield msg_handler.create_and_send_nonmember_event(
@@ -229,7 +230,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
         )
 
         try:
-            content = _parse_json(request)
+            content = parse_json_object_from_request(request)
         except:
             # Turns out we used to ignore the body entirely, and some clients
             # cheekily send invalid bodies.
@@ -433,7 +434,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
             raise AuthError(403, "Guest access not allowed")
 
         try:
-            content = _parse_json(request)
+            content = parse_json_object_from_request(request)
         except:
             # Turns out we used to ignore the body entirely, and some clients
             # cheekily send invalid bodies.
@@ -500,7 +501,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id, txn_id=None):
         requester = yield self.auth.get_user_by_req(request)
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         msg_handler = self.handlers.message_handler
         event = yield msg_handler.create_and_send_nonmember_event(
@@ -548,7 +549,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
         room_id = urllib.unquote(room_id)
         target_user = UserID.from_string(urllib.unquote(user_id))
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         typing_handler = self.handlers.typing_notification_handler
 
@@ -580,7 +581,7 @@ class SearchRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         batch = request.args.get("next_batch", [None])[0]
         results = yield self.handlers.search_handler.search(
@@ -592,17 +593,6 @@ class SearchRestServlet(ClientV1RestServlet):
         defer.returnValue((200, results))
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.",
-                               errcode=Codes.NOT_JSON)
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_txn_path(servlet, regex_string, http_server, with_get=False):
     """Registers a transaction-based path.
 
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 24af322126..b6faa2b0e6 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,11 +17,9 @@
 """
 
 from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
-from synapse.api.errors import SynapseError
 import re
 
 import logging
-import simplejson
 
 
 logger = logging.getLogger(__name__)
@@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)):
         new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
         patterns.append(re.compile("^" + new_prefix + path_regex))
     return patterns
-
-
-def parse_request_allow_empty(request):
-    content = request.content.read()
-    if content is None or content == '':
-        return None
-    try:
-        return simplejson.loads(content)
-    except simplejson.JSONDecodeError:
-        raise SynapseError(400, "Content not JSON.")
-
-
-def parse_json_dict_from_request(request):
-    try:
-        content = simplejson.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.")
-        return content
-    except simplejson.JSONDecodeError:
-        raise SynapseError(400, "Content not JSON.")
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index a614b79d45..688b051580 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -17,10 +17,10 @@ from twisted.internet import defer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import LoginError, SynapseError, Codes
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.util.async import run_on_reactor
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 import logging
 
@@ -41,7 +41,7 @@ class PasswordRestServlet(RestServlet):
     def on_POST(self, request):
         yield run_on_reactor()
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         authed, result, params = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
@@ -114,7 +114,7 @@ class ThreepidRestServlet(RestServlet):
     def on_POST(self, request):
         yield run_on_reactor()
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         threePidCreds = body.get('threePidCreds')
         threePidCreds = body.get('three_pid_creds', threePidCreds)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ec5c21fa1f..b090e66bcf 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,9 +17,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 import logging
 import hmac
@@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet):
             ret = yield self.onEmailTokenRequest(request)
             defer.returnValue(ret)
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the username/password provided to us.
@@ -236,7 +236,7 @@ class RegisterRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def onEmailTokenRequest(self, request):
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         required = ['id_server', 'client_secret', 'email', 'send_attempt']
         absent = []
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 3553f6b040..a158c2209a 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -16,9 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import AuthError, StoreError, SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 
 class TokenRefreshRestServlet(RestServlet):
@@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
         try:
             old_refresh_token = body["refresh_token"]
             auth_handler = self.hs.get_handlers().auth_handler