summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xscripts-dev/federation_client.py42
-rw-r--r--synapse/api/errors.py16
-rw-r--r--synapse/handlers/auth.py164
-rw-r--r--synapse/rest/client/v2_alpha/_base.py41
-rw-r--r--synapse/rest/client/v2_alpha/account.py117
-rw-r--r--synapse/rest/client/v2_alpha/devices.py36
-rw-r--r--synapse/rest/client/v2_alpha/register.py9
-rw-r--r--tests/rest/client/v2_alpha/test_register.py11
8 files changed, 289 insertions, 147 deletions
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 82a90ef6fa..3b28417376 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -123,15 +123,25 @@ def lookup(destination, path):
         except:
             return "https://%s:%d%s" % (destination, 8448, path)
 
-def get_json(origin_name, origin_key, destination, path):
-    request_json = {
-        "method": "GET",
+
+def request_json(method, origin_name, origin_key, destination, path, content):
+    if method is None:
+        if content is None:
+            method = "GET"
+        else:
+            method = "POST"
+
+    json_to_sign = {
+        "method": method,
         "uri": path,
         "origin": origin_name,
         "destination": destination,
     }
 
-    signed_json = sign_json(request_json, origin_key, origin_name)
+    if content is not None:
+        json_to_sign["content"] = json.loads(content)
+
+    signed_json = sign_json(json_to_sign, origin_key, origin_name)
 
     authorization_headers = []
 
@@ -145,10 +155,12 @@ def get_json(origin_name, origin_key, destination, path):
     dest = lookup(destination, path)
     print ("Requesting %s" % dest, file=sys.stderr)
 
-    result = requests.get(
-        dest,
+    result = requests.request(
+        method=method,
+        url=dest,
         headers={"Authorization": authorization_headers[0]},
         verify=False,
+        data=content,
     )
     sys.stderr.write("Status Code: %d\n" % (result.status_code,))
     return result.json()
@@ -187,6 +199,17 @@ def main():
     )
 
     parser.add_argument(
+        "-X", "--method",
+        help="HTTP method to use for the request. Defaults to GET if --data is"
+             "unspecified, POST if it is."
+    )
+
+    parser.add_argument(
+        "--body",
+        help="Data to send as the body of the HTTP request"
+    )
+
+    parser.add_argument(
         "path",
         help="request path. We will add '/_matrix/federation/v1/' to this."
     )
@@ -199,8 +222,11 @@ def main():
     with open(args.signing_key_path) as f:
         key = read_signing_keys(f)[0]
 
-    result = get_json(
-        args.server_name, key, args.destination, "/_matrix/federation/v1/" + args.path
+    result = request_json(
+        args.method,
+        args.server_name, key, args.destination,
+        "/_matrix/federation/v1/" + args.path,
+        content=args.body,
     )
 
     json.dump(result, sys.stdout)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index d0dfa959dc..79b35b3e7c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
     pass
 
 
+class InteractiveAuthIncompleteError(Exception):
+    """An error raised when UI auth is not yet complete
+
+    (This indicates we should return a 401 with 'result' as the body)
+
+    Attributes:
+        result (dict): the server response to the request, which should be
+            passed back to the client
+    """
+    def __init__(self, result):
+        super(InteractiveAuthIncompleteError, self).__init__(
+            "Interactive auth not yet complete",
+        )
+        self.result = result
+
+
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2f30f183ce..573c9db8a1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -17,7 +17,10 @@ from twisted.internet import defer
 
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
+from synapse.api.errors import (
+    AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
+    SynapseError,
+)
 from synapse.module_api import ModuleApi
 from synapse.types import UserID
 from synapse.util.async import run_on_reactor
@@ -46,7 +49,6 @@ class AuthHandler(BaseHandler):
         """
         super(AuthHandler, self).__init__(hs)
         self.checkers = {
-            LoginType.PASSWORD: self._check_password_auth,
             LoginType.RECAPTCHA: self._check_recaptcha,
             LoginType.EMAIL_IDENTITY: self._check_email_identity,
             LoginType.MSISDN: self._check_msisdn,
@@ -75,15 +77,76 @@ class AuthHandler(BaseHandler):
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
 
-        login_types = set()
+        # we keep this as a list despite the O(N^2) implication so that we can
+        # keep PASSWORD first and avoid confusing clients which pick the first
+        # type in the list. (NB that the spec doesn't require us to do so and
+        # clients which favour types that they don't understand over those that
+        # they do are technically broken)
+        login_types = []
         if self._password_enabled:
-            login_types.add(LoginType.PASSWORD)
+            login_types.append(LoginType.PASSWORD)
         for provider in self.password_providers:
             if hasattr(provider, "get_supported_login_types"):
-                login_types.update(
-                    provider.get_supported_login_types().keys()
-                )
-        self._supported_login_types = frozenset(login_types)
+                for t in provider.get_supported_login_types().keys():
+                    if t not in login_types:
+                        login_types.append(t)
+        self._supported_login_types = login_types
+
+    @defer.inlineCallbacks
+    def validate_user_via_ui_auth(self, requester, request_body, clientip):
+        """
+        Checks that the user is who they claim to be, via a UI auth.
+
+        This is used for things like device deletion and password reset where
+        the user already has a valid access token, but we want to double-check
+        that it isn't stolen by re-authenticating them.
+
+        Args:
+            requester (Requester): The user, as given by the access token
+
+            request_body (dict): The body of the request sent by the client
+
+            clientip (str): The IP address of the client.
+
+        Returns:
+            defer.Deferred[dict]: the parameters for this request (which may
+                have been given only in a previous call).
+
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                any of the permitted login flows
+
+            AuthError if the client has completed a login flow, and it gives
+                a different user to `requester`
+        """
+
+        # build a list of supported flows
+        flows = [
+            [login_type] for login_type in self._supported_login_types
+        ]
+
+        result, params, _ = yield self.check_auth(
+            flows, request_body, clientip,
+        )
+
+        # find the completed login type
+        for login_type in self._supported_login_types:
+            if login_type not in result:
+                continue
+
+            user_id = result[login_type]
+            break
+        else:
+            # this can't happen
+            raise Exception(
+                "check_auth returned True but no successful login type",
+            )
+
+        # check that the UI auth matched the access token
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Invalid auth")
+
+        defer.returnValue(params)
 
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
@@ -95,26 +158,36 @@ class AuthHandler(BaseHandler):
         session with a map, which maps each auth-type (str) to the relevant
         identity authenticated by that auth-type (mostly str, but for captcha, bool).
 
+        If no auth flows have been completed successfully, raises an
+        InteractiveAuthIncompleteError. To handle this, you can use
+        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+        decorator.
+
         Args:
             flows (list): A list of login flows. Each flow is an ordered list of
                           strings representing auth-types. At least one full
                           flow must be completed in order for auth to be successful.
+
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
+
             clientip (str): The IP address of the client.
+
         Returns:
-            A tuple of (authed, dict, dict, session_id) where authed is true if
-            the client has successfully completed an auth flow. If it is true
-            the first dict contains the authenticated credentials of each stage.
+            defer.Deferred[dict, dict, str]: a deferred tuple of
+                (creds, params, session_id).
+
+                'creds' contains the authenticated credentials of each stage.
 
-            If authed is false, the first dictionary is the server response to
-            the login request and should be passed back to the client.
+                'params' contains the parameters for this request (which may
+                have been given only in a previous call).
 
-            In either case, the second dict contains the parameters for this
-            request (which may have been given only in a previous call).
+                'session_id' is the ID of this session, either passed in by the
+                client or assigned by this call
 
-            session_id is the ID of this session, either passed in by the client
-            or assigned by the call to check_auth
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                all the stages in any of the permitted flows.
         """
 
         authdict = None
@@ -142,11 +215,8 @@ class AuthHandler(BaseHandler):
             clientdict = session['clientdict']
 
         if not authdict:
-            defer.returnValue(
-                (
-                    False, self._auth_dict_for_flows(flows, session),
-                    clientdict, session['id']
-                )
+            raise InteractiveAuthIncompleteError(
+                self._auth_dict_for_flows(flows, session),
             )
 
         if 'creds' not in session:
@@ -157,14 +227,12 @@ class AuthHandler(BaseHandler):
         errordict = {}
         if 'type' in authdict:
             login_type = authdict['type']
-            if login_type not in self.checkers:
-                raise LoginError(400, "", Codes.UNRECOGNIZED)
             try:
-                result = yield self.checkers[login_type](authdict, clientip)
+                result = yield self._check_auth_dict(authdict, clientip)
                 if result:
                     creds[login_type] = result
                     self._save_session(session)
-            except LoginError, e:
+            except LoginError as e:
                 if login_type == LoginType.EMAIL_IDENTITY:
                     # riot used to have a bug where it would request a new
                     # validation token (thus sending a new email) each time it
@@ -173,7 +241,7 @@ class AuthHandler(BaseHandler):
                     #
                     # Grandfather in the old behaviour for now to avoid
                     # breaking old riot deployments.
-                    raise e
+                    raise
 
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
@@ -190,12 +258,14 @@ class AuthHandler(BaseHandler):
                     "Auth completed with creds: %r. Client dict has keys: %r",
                     creds, clientdict.keys()
                 )
-                defer.returnValue((True, creds, clientdict, session['id']))
+                defer.returnValue((creds, clientdict, session['id']))
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
         ret.update(errordict)
-        defer.returnValue((False, ret, clientdict, session['id']))
+        raise InteractiveAuthIncompleteError(
+            ret,
+        )
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
@@ -268,17 +338,35 @@ class AuthHandler(BaseHandler):
         return sess.setdefault('serverdict', {}).get(key, default)
 
     @defer.inlineCallbacks
-    def _check_password_auth(self, authdict, _):
-        if "user" not in authdict or "password" not in authdict:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+    def _check_auth_dict(self, authdict, clientip):
+        """Attempt to validate the auth dict provided by a client
+
+        Args:
+            authdict (object): auth dict provided by the client
+            clientip (str): IP address of the client
+
+        Returns:
+            Deferred: result of the stage verification.
+
+        Raises:
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
+        """
+        login_type = authdict['type']
+        checker = self.checkers.get(login_type)
+        if checker is not None:
+            res = yield checker(authdict, clientip)
+            defer.returnValue(res)
+
+        # build a v1-login-style dict out of the authdict and fall back to the
+        # v1 code
+        user_id = authdict.get("user")
 
-        user_id = authdict["user"]
-        password = authdict["password"]
+        if user_id is None:
+            raise SynapseError(400, "", Codes.MISSING_PARAM)
 
-        (canonical_id, callback) = yield self.validate_login(user_id, {
-            "type": LoginType.PASSWORD,
-            "password": password,
-        })
+        (canonical_id, callback) = yield self.validate_login(user_id, authdict)
         defer.returnValue(canonical_id)
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 1f5bc24cc3..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
 
 """This module contains base REST classes for constructing client v1 servlets.
 """
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
 import re
 
-import logging
+from twisted.internet import defer
 
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
 
 logger = logging.getLogger(__name__)
 
@@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
         filter_json['room']['timeline']["limit"] = min(
             filter_json['room']['timeline']['limit'],
             filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+    """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+    Takes a on_POST method which returns a deferred (errcode, body) response
+    and adds exception handling to turn a InteractiveAuthIncompleteError into
+    a 401 response.
+
+    Normal usage is:
+
+    @interactive_auth_handler
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        # ...
+        yield self.auth_handler.check_auth
+            """
+    def wrapped(*args, **kwargs):
+        res = defer.maybeDeferred(orig, *args, **kwargs)
+        res.addErrback(_catch_incomplete_interactive_auth)
+        return res
+    return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+    """helper for interactive_auth_handler
+
+    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+    Args:
+        f (failure.Failure):
+    """
+    f.trap(InteractiveAuthIncompleteError)
+    return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c26ce63bcf..385a3ad2ec 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -19,14 +19,14 @@ from twisted.internet import defer
 
 from synapse.api.auth import has_access_token
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.errors import Codes, SynapseError
 from synapse.http.servlet import (
     RestServlet, assert_params_in_request,
     parse_json_object_from_request,
 )
 from synapse.util.async import run_on_reactor
 from synapse.util.msisdn import phone_number_to_msisdn
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
@@ -100,49 +100,53 @@ class PasswordRestServlet(RestServlet):
         self.datastore = self.hs.get_datastore()
         self._set_password_handler = hs.get_set_password_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
-        yield run_on_reactor()
-
         body = parse_json_object_from_request(request)
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [LoginType.PASSWORD],
-            [LoginType.EMAIL_IDENTITY],
-            [LoginType.MSISDN],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
+        # there are two possibilities here. Either the user does not have an
+        # access token, and needs to do a password reset; or they have one and
+        # need to validate their identity.
+        #
+        # In the first case, we offer a couple of means of identifying
+        # themselves (email and msisdn, though it's unclear if msisdn actually
+        # works).
+        #
+        # In the second case, we require a password to confirm their identity.
 
-        user_id = None
-        requester = None
-
-        if LoginType.PASSWORD in result:
-            # if using password, they should also be logged in
+        if has_access_token(request):
             requester = yield self.auth.get_user_by_req(request)
-            user_id = requester.user.to_string()
-            if user_id != result[LoginType.PASSWORD]:
-                raise LoginError(400, "", Codes.UNKNOWN)
-        elif LoginType.EMAIL_IDENTITY in result:
-            threepid = result[LoginType.EMAIL_IDENTITY]
-            if 'medium' not in threepid or 'address' not in threepid:
-                raise SynapseError(500, "Malformed threepid")
-            if threepid['medium'] == 'email':
-                # For emails, transform the address to lowercase.
-                # We store all email addreses as lowercase in the DB.
-                # (See add_threepid in synapse/handlers/auth.py)
-                threepid['address'] = threepid['address'].lower()
-            # if using email, we must know about the email they're authing with!
-            threepid_user_id = yield self.datastore.get_user_id_by_threepid(
-                threepid['medium'], threepid['address']
+            params = yield self.auth_handler.validate_user_via_ui_auth(
+                requester, body, self.hs.get_ip_from_request(request),
             )
-            if not threepid_user_id:
-                raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
-            user_id = threepid_user_id
+            user_id = requester.user.to_string()
         else:
-            logger.error("Auth succeeded but no known type!", result.keys())
-            raise SynapseError(500, "", Codes.UNKNOWN)
+            requester = None
+            result, params, _ = yield self.auth_handler.check_auth(
+                [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+                body, self.hs.get_ip_from_request(request),
+            )
+
+            if LoginType.EMAIL_IDENTITY in result:
+                threepid = result[LoginType.EMAIL_IDENTITY]
+                if 'medium' not in threepid or 'address' not in threepid:
+                    raise SynapseError(500, "Malformed threepid")
+                if threepid['medium'] == 'email':
+                    # For emails, transform the address to lowercase.
+                    # We store all email addreses as lowercase in the DB.
+                    # (See add_threepid in synapse/handlers/auth.py)
+                    threepid['address'] = threepid['address'].lower()
+                # if using email, we must know about the email they're authing with!
+                threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+                    threepid['medium'], threepid['address']
+                )
+                if not threepid_user_id:
+                    raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+                user_id = threepid_user_id
+            else:
+                logger.error("Auth succeeded but no known type!", result.keys())
+                raise SynapseError(500, "", Codes.UNKNOWN)
 
         if 'new_password' not in params:
             raise SynapseError(400, "", Codes.MISSING_PARAM)
@@ -168,47 +172,26 @@ class DeactivateAccountRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
-        # if the caller provides an access token, it ought to be valid.
-        requester = None
-        if has_access_token(request):
-            requester = yield self.auth.get_user_by_req(
-                request,
-            )  # type: synapse.types.Requester
+        requester = yield self.auth.get_user_by_req(request)
 
         # allow ASes to dectivate their own users
-        if requester and requester.app_service:
+        if requester.app_service:
             yield self._deactivate_account_handler.deactivate_account(
                 requester.user.to_string()
             )
             defer.returnValue((200, {}))
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
-
-        if LoginType.PASSWORD in result:
-            user_id = result[LoginType.PASSWORD]
-            # if using password, they should also be logged in
-            if requester is None:
-                raise SynapseError(
-                    400,
-                    "Deactivate account requires an access_token",
-                    errcode=Codes.MISSING_TOKEN
-                )
-            if requester.user.to_string() != user_id:
-                raise LoginError(400, "", Codes.UNKNOWN)
-        else:
-            logger.error("Auth succeeded but no known type!", result.keys())
-            raise SynapseError(500, "", Codes.UNKNOWN)
-
-        yield self._deactivate_account_handler.deactivate_account(user_id)
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
+        yield self._deactivate_account_handler.deactivate_account(
+            requester.user.to_string(),
+        )
         defer.returnValue((200, {}))
 
 
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 5321e5abbb..35d58b367a 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -17,9 +17,9 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.api import constants, errors
+from synapse.api import errors
 from synapse.http import servlet
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
@@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
         self.device_handler = hs.get_device_handler()
         self.auth_handler = hs.get_auth_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+
         try:
             body = servlet.parse_json_object_from_request(request)
         except errors.SynapseError as e:
@@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
                 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
             )
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [constants.LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
 
-        requester = yield self.auth.get_user_by_req(request)
         yield self.device_handler.delete_devices(
             requester.user.to_string(),
             body['devices'],
@@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet):
         )
         defer.returnValue((200, device))
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_DELETE(self, request, device_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet):
             else:
                 raise
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [constants.LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
-
-        # check that the UI auth matched the access token
-        user_id = result[constants.LoginType.PASSWORD]
-        if user_id != requester.user.to_string():
-            raise errors.AuthError(403, "Invalid auth")
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
 
-        yield self.device_handler.delete_device(user_id, device_id)
+        yield self.device_handler.delete_device(
+            requester.user.to_string(), device_id,
+        )
         defer.returnValue((200, {}))
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 9e2f7308ce..e9d88a8895 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -27,7 +27,7 @@ from synapse.http.servlet import (
 )
 from synapse.util.msisdn import phone_number_to_msisdn
 
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 import logging
 import hmac
@@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
@@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
                 ])
 
-        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+        auth_result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )
 
-        if not authed:
-            defer.returnValue((401, auth_result))
-            return
-
         if registered_user_id is not None:
             logger.info(
                 "Already registered user ID %r for this session",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 821c735528..096f771bea 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,5 +1,7 @@
+from twisted.python import failure
+
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
@@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             side_effect=lambda x: self.appservice)
         )
 
-        self.auth_result = (False, None, None, None)
+        self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
         self.auth_handler = Mock(
             check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
             get_session_data=Mock(return_value=None)
@@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.request.args = {
             "access_token": "i_am_an_app_service"
         }
+
         self.request_data = json.dumps({
             "username": "kermit"
         })
@@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "device_id": device_id,
         })
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (True, None, {
+        self.auth_result = (None, {
             "username": "kermit",
             "password": "monkey"
         }, None)
@@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "password": "monkey"
         })
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (True, None, {
+        self.auth_result = (None, {
             "username": "kermit",
             "password": "monkey"
         }, None)