diff options
Diffstat (limited to 'synapse/rest/client/v1/register.py')
-rw-r--r-- | synapse/rest/client/v1/register.py | 43 |
1 files changed, 15 insertions, 28 deletions
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index c10320dedf..25a143af8d 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -14,21 +14,19 @@ # limitations under the License. """This module contains REST servlets to do with registration: /register""" +import hmac +import logging +from hashlib import sha1 + from twisted.internet import defer -from synapse.api.errors import SynapseError, Codes -from synapse.api.constants import LoginType -from synapse.api.auth import get_access_token_from_request -from .base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils -from synapse.http.servlet import parse_json_object_from_request +from synapse.api.constants import LoginType +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.types import create_requester -from hashlib import sha1 -import hmac -import logging - -from six import string_types +from .base import ClientV1RestServlet, client_path_patterns logger = logging.getLogger(__name__) @@ -66,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.handlers = hs.get_handlers() @@ -123,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet): session = (register_json["session"] if "session" in register_json else None) login_type = None - if "type" not in register_json: - raise SynapseError(400, "Missing 'type' key.") + assert_params_in_dict(register_json, ["type"]) try: login_type = register_json["type"] @@ -309,11 +307,9 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_app_service(self, request, register_json, session): - as_token = get_access_token_from_request(request) - - if "user" not in register_json: - raise SynapseError(400, "Expected 'user' key.") + as_token = self.auth.get_access_token_from_request(request) + assert_params_in_dict(register_json, ["user"]) user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler @@ -330,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_shared_secret(self, request, register_json, session): - if not isinstance(register_json.get("mac", None), string_types): - raise SynapseError(400, "Expected mac.") - if not isinstance(register_json.get("user", None), string_types): - raise SynapseError(400, "Expected 'user' key.") - if not isinstance(register_json.get("password", None), string_types): - raise SynapseError(400, "Expected 'password' key.") + assert_params_in_dict(register_json, ["mac", "user", "password"]) if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -399,7 +390,7 @@ class CreateUserRestServlet(ClientV1RestServlet): def on_POST(self, request): user_json = parse_json_object_from_request(request) - access_token = get_access_token_from_request(request) + access_token = self.auth.get_access_token_from_request(request) app_service = self.store.get_app_service_by_token( access_token ) @@ -418,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def _do_create(self, requester, user_json): - if "localpart" not in user_json: - raise SynapseError(400, "Expected 'localpart' key.") - - if "displayname" not in user_json: - raise SynapseError(400, "Expected 'displayname' key.") + assert_params_in_dict(user_json, ["localpart", "displayname"]) localpart = user_json["localpart"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8") |