summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2015-08-12 17:07:22 +0100
committerMark Haines <mark.haines@matrix.org>2015-08-12 17:21:14 +0100
commit998a72d4d9ec6e73000888dcdf51437ec427fbee (patch)
treef8fa9d5deb820b49eb3a216e194ee83e14fb9eda /synapse/rest/client
parentBump the version of twisted needed for setup_requires to 15.2.1 (diff)
parentMerge pull request #220 from matrix-org/markjh/generate_keys (diff)
downloadsynapse-998a72d4d9ec6e73000888dcdf51437ec427fbee.tar.xz
Merge branch 'develop' into markjh/twisted-15
Conflicts:
	synapse/http/matrixfederationclient.py
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/v1/login.py75
-rw-r--r--synapse/rest/client/v1/room.py2
-rw-r--r--synapse/rest/client/v1/transactions.py6
-rw-r--r--synapse/rest/client/v2_alpha/__init__.py6
-rw-r--r--synapse/rest/client/v2_alpha/keys.py276
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py55
-rw-r--r--synapse/rest/client/v2_alpha/register.py132
7 files changed, 495 insertions, 57 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b2257b749d..998d4d44c6 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,14 +20,32 @@ from synapse.types import UserID
 from base import ClientV1RestServlet, client_path_pattern
 
 import simplejson as json
+import urllib
+
+import logging
+from saml2 import BINDING_HTTP_POST
+from saml2 import config
+from saml2.client import Saml2Client
+
+
+logger = logging.getLogger(__name__)
 
 
 class LoginRestServlet(ClientV1RestServlet):
     PATTERN = client_path_pattern("/login$")
     PASS_TYPE = "m.login.password"
+    SAML2_TYPE = "m.login.saml2"
+
+    def __init__(self, hs):
+        super(LoginRestServlet, self).__init__(hs)
+        self.idp_redirect_url = hs.config.saml2_idp_redirect_url
+        self.saml2_enabled = hs.config.saml2_enabled
 
     def on_GET(self, request):
-        return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
+        flows = [{"type": LoginRestServlet.PASS_TYPE}]
+        if self.saml2_enabled:
+            flows.append({"type": LoginRestServlet.SAML2_TYPE})
+        return (200, {"flows": flows})
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
             if login_submission["type"] == LoginRestServlet.PASS_TYPE:
                 result = yield self.do_password_login(login_submission)
                 defer.returnValue(result)
+            elif self.saml2_enabled and (login_submission["type"] ==
+                                         LoginRestServlet.SAML2_TYPE):
+                relay_state = ""
+                if "relay_state" in login_submission:
+                    relay_state = "&RelayState="+urllib.quote(
+                                  login_submission["relay_state"])
+                result = {
+                    "uri": "%s%s" % (self.idp_redirect_url, relay_state)
+                }
+                defer.returnValue((200, result))
             else:
                 raise SynapseError(400, "Bad login type.")
         except KeyError:
@@ -94,6 +122,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
             )
 
 
+class SAML2RestServlet(ClientV1RestServlet):
+    PATTERN = client_path_pattern("/login/saml2")
+
+    def __init__(self, hs):
+        super(SAML2RestServlet, self).__init__(hs)
+        self.sp_config = hs.config.saml2_config_path
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        saml2_auth = None
+        try:
+            conf = config.SPConfig()
+            conf.load_file(self.sp_config)
+            SP = Saml2Client(conf)
+            saml2_auth = SP.parse_authn_request_response(
+                request.args['SAMLResponse'][0], BINDING_HTTP_POST)
+        except Exception, e:        # Not authenticated
+            logger.exception(e)
+        if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
+            username = saml2_auth.name_id.text
+            handler = self.handlers.registration_handler
+            (user_id, token) = yield handler.register_saml2(username)
+            # Forward to the RelayState callback along with ava
+            if 'RelayState' in request.args:
+                request.redirect(urllib.unquote(
+                                 request.args['RelayState'][0]) +
+                                 '?status=authenticated&access_token=' +
+                                 token + '&user_id=' + user_id + '&ava=' +
+                                 urllib.quote(json.dumps(saml2_auth.ava)))
+                request.finish()
+                defer.returnValue(None)
+            defer.returnValue((200, {"status": "authenticated",
+                                     "user_id": user_id, "token": token,
+                                     "ava": saml2_auth.ava}))
+        elif 'RelayState' in request.args:
+            request.redirect(urllib.unquote(
+                             request.args['RelayState'][0]) +
+                             '?status=not_authenticated')
+            request.finish()
+            defer.returnValue(None)
+        defer.returnValue((200, {"status": "not_authenticated"}))
+
+
 def _parse_json(request):
     try:
         content = json.loads(request.content.read())
@@ -106,4 +177,6 @@ def _parse_json(request):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    if hs.config.saml2_enabled:
+        SAML2RestServlet(hs).register(http_server)
     # TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0346afb1b4..b4a70cba99 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
             if "user_id" not in content:
                 raise SynapseError(400, "Missing user_id key.")
             state_key = content["user_id"]
+            # make sure it looks like a user ID; it'll throw if it's invalid.
+            UserID.from_string(state_key)
 
             if membership_action == "kick":
                 membership_action = "leave"
diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py
index d933fea18a..b861069b89 100644
--- a/synapse/rest/client/v1/transactions.py
+++ b/synapse/rest/client/v1/transactions.py
@@ -39,10 +39,10 @@ class HttpTransactionStore(object):
             A tuple of (HTTP response code, response content) or None.
         """
         try:
-            logger.debug("get_response Key: %s TxnId: %s", key, txn_id)
+            logger.debug("get_response TxnId: %s", txn_id)
             (last_txn_id, response) = self.transactions[key]
             if txn_id == last_txn_id:
-                logger.info("get_response: Returning a response for %s", key)
+                logger.info("get_response: Returning a response for %s", txn_id)
                 return response
         except KeyError:
             pass
@@ -58,7 +58,7 @@ class HttpTransactionStore(object):
             txn_id (str): The transaction ID for this request.
             response (tuple): A tuple of (HTTP response code, response content)
         """
-        logger.debug("store_response Key: %s TxnId: %s", key, txn_id)
+        logger.debug("store_response TxnId: %s", txn_id)
         self.transactions[key] = (txn_id, response)
 
     def store_client_transaction(self, request, txn_id, response):
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 7d1aff4307..33f961e898 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -18,7 +18,9 @@ from . import (
     filter,
     account,
     register,
-    auth
+    auth,
+    receipts,
+    keys,
 )
 
 from synapse.http.server import JsonResource
@@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
         account.register_servlets(hs, client_resource)
         register.register_servlets(hs, client_resource)
         auth.register_servlets(hs, client_resource)
+        receipts.register_servlets(hs, client_resource)
+        keys.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
new file mode 100644
index 0000000000..5f3a6207b5
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -0,0 +1,276 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from syutil.jsonutil import encode_canonical_json
+
+from ._base import client_v2_pattern
+
+import simplejson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class KeyUploadServlet(RestServlet):
+    """
+    POST /keys/upload/<device_id> HTTP/1.1
+    Content-Type: application/json
+
+    {
+      "device_keys": {
+        "user_id": "<user_id>",
+        "device_id": "<device_id>",
+        "valid_until_ts": <millisecond_timestamp>,
+        "algorithms": [
+          "m.olm.curve25519-aes-sha256",
+        ]
+        "keys": {
+          "<algorithm>:<device_id>": "<key_base64>",
+        },
+        "signatures:" {
+          "<user_id>" {
+            "<algorithm>:<device_id>": "<signature_base64>"
+      } } },
+      "one_time_keys": {
+        "<algorithm>:<key_id>": "<key_base64>"
+      },
+    }
+    """
+    PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
+
+    def __init__(self, hs):
+        super(KeyUploadServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        user_id = auth_user.to_string()
+        # TODO: Check that the device_id matches that in the authentication
+        # or derive the device_id from the authentication instead.
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        time_now = self.clock.time_msec()
+
+        # TODO: Validate the JSON to make sure it has the right keys.
+        device_keys = body.get("device_keys", None)
+        if device_keys:
+            logger.info(
+                "Updating device_keys for device %r for user %r at %d",
+                device_id, auth_user, time_now
+            )
+            # TODO: Sign the JSON with the server key
+            yield self.store.set_e2e_device_keys(
+                user_id, device_id, time_now,
+                encode_canonical_json(device_keys)
+            )
+
+        one_time_keys = body.get("one_time_keys", None)
+        if one_time_keys:
+            logger.info(
+                "Adding %d one_time_keys for device %r for user %r at %d",
+                len(one_time_keys), device_id, user_id, time_now
+            )
+            key_list = []
+            for key_id, key_json in one_time_keys.items():
+                algorithm, key_id = key_id.split(":")
+                key_list.append((
+                    algorithm, key_id, encode_canonical_json(key_json)
+                ))
+
+            yield self.store.add_e2e_one_time_keys(
+                user_id, device_id, time_now, key_list
+            )
+
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        user_id = auth_user.to_string()
+
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class KeyQueryServlet(RestServlet):
+    """
+    GET /keys/query/<user_id> HTTP/1.1
+
+    GET /keys/query/<user_id>/<device_id> HTTP/1.1
+
+    POST /keys/query HTTP/1.1
+    Content-Type: application/json
+    {
+      "device_keys": {
+        "<user_id>": ["<device_id>"]
+    } }
+
+    HTTP/1.1 200 OK
+    {
+      "device_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "user_id": "<user_id>", // Duplicated to be signed
+            "device_id": "<device_id>", // Duplicated to be signed
+            "valid_until_ts": <millisecond_timestamp>,
+            "algorithms": [ // List of supported algorithms
+              "m.olm.curve25519-aes-sha256",
+            ],
+            "keys": { // Must include a ed25519 signing key
+              "<algorithm>:<key_id>": "<key_base64>",
+            },
+            "signatures:" {
+              // Must be signed with device's ed25519 key
+              "<user_id>/<device_id>": {
+                "<algorithm>:<key_id>": "<signature_base64>"
+              }
+              // Must be signed by this server.
+              "<server_name>": {
+                "<algorithm>:<key_id>": "<signature_base64>"
+    } } } } } }
+    """
+
+    PATTERN = client_v2_pattern(
+        "/keys/query(?:"
+        "/(?P<user_id>[^/]*)(?:"
+        "/(?P<device_id>[^/]*)"
+        ")?"
+        ")?"
+    )
+
+    def __init__(self, hs):
+        super(KeyQueryServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id, device_id):
+        logger.debug("onPOST")
+        yield self.auth.get_user_by_req(request)
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        query = []
+        for user_id, device_ids in body.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+        results = yield self.store.get_e2e_device_keys(query)
+        defer.returnValue(self.json_result(request, results))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, device_id):
+        auth_user, client_info = yield self.auth.get_user_by_req(request)
+        auth_user_id = auth_user.to_string()
+        if not user_id:
+            user_id = auth_user_id
+        if not device_id:
+            device_id = None
+        # Returns a map of user_id->device_id->json_bytes.
+        results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
+        defer.returnValue(self.json_result(request, results))
+
+    def json_result(self, request, results):
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+        return (200, {"device_keys": json_result})
+
+
+class OneTimeKeyServlet(RestServlet):
+    """
+    GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
+
+    POST /keys/claim HTTP/1.1
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": "<algorithm>"
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+      "one_time_keys": {
+        "<user_id>": {
+          "<device_id>": {
+            "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
+    """
+    PATTERN = client_v2_pattern(
+        "/keys/claim(?:/?|(?:/"
+        "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
+        ")?)"
+    )
+
+    def __init__(self, hs):
+        super(OneTimeKeyServlet, self).__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id, device_id, algorithm):
+        yield self.auth.get_user_by_req(request)
+        results = yield self.store.claim_e2e_one_time_keys(
+            [(user_id, device_id, algorithm)]
+        )
+        defer.returnValue(self.json_result(request, results))
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, user_id, device_id, algorithm):
+        yield self.auth.get_user_by_req(request)
+        try:
+            body = json.loads(request.content.read())
+        except:
+            raise SynapseError(400, "Invalid key JSON")
+        query = []
+        for user_id, device_keys in body.get("one_time_keys", {}).items():
+            for device_id, algorithm in device_keys.items():
+                query.append((user_id, device_id, algorithm))
+        results = yield self.store.claim_e2e_one_time_keys(query)
+        defer.returnValue(self.json_result(request, results))
+
+    def json_result(self, request, results):
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, keys in device_keys.items():
+                for key_id, json_bytes in keys.items():
+                    json_result.setdefault(user_id, {})[device_id] = {
+                        key_id: json.loads(json_bytes)
+                    }
+        return (200, {"one_time_keys": json_result})
+
+
+def register_servlets(hs, http_server):
+    KeyUploadServlet(hs).register(http_server)
+    KeyQueryServlet(hs).register(http_server)
+    OneTimeKeyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
new file mode 100644
index 0000000000..40406e2ede
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet
+from ._base import client_v2_pattern
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptRestServlet(RestServlet):
+    PATTERN = client_v2_pattern(
+        "/rooms/(?P<room_id>[^/]*)"
+        "/receipt/(?P<receipt_type>[^/]*)"
+        "/(?P<event_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(ReceiptRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.receipts_handler = hs.get_handlers().receipts_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id, receipt_type, event_id):
+        user, client = yield self.auth.get_user_by_req(request)
+
+        yield self.receipts_handler.received_client_receipt(
+            room_id,
+            receipt_type,
+            user_id=user.to_string(),
+            event_id=event_id
+        )
+
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 72dfb876c5..b5926f9ca6 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes
 from synapse.http.servlet import RestServlet
 
-from ._base import client_v2_pattern, parse_request_allow_empty
+from ._base import client_v2_pattern, parse_json_dict_from_request
 
 import logging
 import hmac
@@ -55,21 +55,55 @@ class RegisterRestServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
-
-        body = parse_request_allow_empty(request)
-        if 'password' not in body:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
-
+        body = parse_json_dict_from_request(request)
+
+        # we do basic sanity checks here because the auth layer will store these
+        # in sessions. Pull out the username/password provided to us.
+        desired_password = None
+        if 'password' in body:
+            if (not isinstance(body['password'], basestring) or
+                    len(body['password']) > 512):
+                raise SynapseError(400, "Invalid password")
+            desired_password = body["password"]
+
+        desired_username = None
         if 'username' in body:
+            if (not isinstance(body['username'], basestring) or
+                    len(body['username']) > 512):
+                raise SynapseError(400, "Invalid username")
             desired_username = body['username']
-            yield self.registration_handler.check_username(desired_username)
 
-        is_using_shared_secret = False
-        is_application_server = False
-
-        service = None
+        appservice = None
         if 'access_token' in request.args:
-            service = yield self.auth.get_appservice_by_req(request)
+            appservice = yield self.auth.get_appservice_by_req(request)
+
+        # fork off as soon as possible for ASes and shared secret auth which
+        # have completely different registration flows to normal users
+
+        # == Application Service Registration ==
+        if appservice:
+            result = yield self._do_appservice_registration(
+                desired_username, request.args["access_token"][0]
+            )
+            defer.returnValue((200, result))  # we throw for non 200 responses
+            return
+
+        # == Shared Secret Registration == (e.g. create new user scripts)
+        if 'mac' in body:
+            # FIXME: Should we really be determining if this is shared secret
+            # auth based purely on the 'mac' key?
+            result = yield self._do_shared_secret_registration(
+                desired_username, desired_password, body["mac"]
+            )
+            defer.returnValue((200, result))  # we throw for non 200 responses
+            return
+
+        # == Normal User Registration == (everyone else)
+        if self.hs.config.disable_registration:
+            raise SynapseError(403, "Registration has been disabled")
+
+        if desired_username is not None:
+            yield self.registration_handler.check_username(desired_username)
 
         if self.hs.config.enable_registration_captcha:
             flows = [
@@ -82,39 +116,20 @@ class RegisterRestServlet(RestServlet):
                 [LoginType.EMAIL_IDENTITY]
             ]
 
-        result = None
-        if service:
-            is_application_server = True
-            params = body
-        elif 'mac' in body:
-            # Check registration-specific shared secret auth
-            if 'username' not in body:
-                raise SynapseError(400, "", Codes.MISSING_PARAM)
-            self._check_shared_secret_auth(
-                body['username'], body['mac']
-            )
-            is_using_shared_secret = True
-            params = body
-        else:
-            authed, result, params = yield self.auth_handler.check_auth(
-                flows, body, self.hs.get_ip_from_request(request)
-            )
-
-            if not authed:
-                defer.returnValue((401, result))
-
-        can_register = (
-            not self.hs.config.disable_registration
-            or is_application_server
-            or is_using_shared_secret
+        authed, result, params = yield self.auth_handler.check_auth(
+            flows, body, self.hs.get_ip_from_request(request)
         )
-        if not can_register:
-            raise SynapseError(403, "Registration has been disabled")
 
+        if not authed:
+            defer.returnValue((401, result))
+            return
+
+        # NB: This may be from the auth handler and NOT from the POST
         if 'password' not in params:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
-        desired_username = params['username'] if 'username' in params else None
-        new_password = params['password']
+            raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
+
+        desired_username = params.get("username", None)
+        new_password = params.get("password", None)
 
         (user_id, token) = yield self.registration_handler.register(
             localpart=desired_username,
@@ -147,18 +162,21 @@ class RegisterRestServlet(RestServlet):
             else:
                 logger.info("bind_email not specified: not binding email")
 
-        result = {
-            "user_id": user_id,
-            "access_token": token,
-            "home_server": self.hs.hostname,
-        }
-
+        result = self._create_registration_details(user_id, token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
         return 200, {}
 
-    def _check_shared_secret_auth(self, username, mac):
+    @defer.inlineCallbacks
+    def _do_appservice_registration(self, username, as_token):
+        (user_id, token) = yield self.registration_handler.appservice_register(
+            username, as_token
+        )
+        defer.returnValue(self._create_registration_details(user_id, token))
+
+    @defer.inlineCallbacks
+    def _do_shared_secret_registration(self, username, password, mac):
         if not self.hs.config.registration_shared_secret:
             raise SynapseError(400, "Shared secret registration is not enabled")
 
@@ -174,13 +192,23 @@ class RegisterRestServlet(RestServlet):
             digestmod=sha1,
         ).hexdigest()
 
-        if compare_digest(want_mac, got_mac):
-            return True
-        else:
+        if not compare_digest(want_mac, got_mac):
             raise SynapseError(
                 403, "HMAC incorrect",
             )
 
+        (user_id, token) = yield self.registration_handler.register(
+            localpart=username, password=password
+        )
+        defer.returnValue(self._create_registration_details(user_id, token))
+
+    def _create_registration_details(self, user_id, token):
+        return {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+        }
+
 
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)