summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <github@rvanderhoff.org.uk>2017-12-05 09:40:38 +0000
committerGitHub <noreply@github.com>2017-12-05 09:40:38 +0000
commitaa6ecf09846951bd822970ba2b6d7d38abeec0ff (patch)
treef675cc0f589fb974b529024ca3403be38f03a628
parentMerge pull request #2716 from matrix-org/rav/federation_client_post (diff)
parentRefactor UI auth implementation (diff)
downloadsynapse-aa6ecf09846951bd822970ba2b6d7d38abeec0ff.tar.xz
Merge pull request #2727 from matrix-org/rav/refactor_ui_auth_return
Refactor UI auth implementation
-rw-r--r--synapse/api/errors.py16
-rw-r--r--synapse/handlers/auth.py46
-rw-r--r--synapse/rest/client/v2_alpha/_base.py41
-rw-r--r--synapse/rest/client/v2_alpha/account.py14
-rw-r--r--synapse/rest/client/v2_alpha/devices.py14
-rw-r--r--synapse/rest/client/v2_alpha/register.py9
-rw-r--r--tests/rest/client/v2_alpha/test_register.py11
7 files changed, 103 insertions, 48 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index d0dfa959dc..79b35b3e7c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
     pass
 
 
+class InteractiveAuthIncompleteError(Exception):
+    """An error raised when UI auth is not yet complete
+
+    (This indicates we should return a 401 with 'result' as the body)
+
+    Attributes:
+        result (dict): the server response to the request, which should be
+            passed back to the client
+    """
+    def __init__(self, result):
+        super(InteractiveAuthIncompleteError, self).__init__(
+            "Interactive auth not yet complete",
+        )
+        self.result = result
+
+
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2f30f183ce..28c80608a7 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -17,7 +17,10 @@ from twisted.internet import defer
 
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
+from synapse.api.errors import (
+    AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
+    SynapseError,
+)
 from synapse.module_api import ModuleApi
 from synapse.types import UserID
 from synapse.util.async import run_on_reactor
@@ -95,26 +98,36 @@ class AuthHandler(BaseHandler):
         session with a map, which maps each auth-type (str) to the relevant
         identity authenticated by that auth-type (mostly str, but for captcha, bool).
 
+        If no auth flows have been completed successfully, raises an
+        InteractiveAuthIncompleteError. To handle this, you can use
+        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+        decorator.
+
         Args:
             flows (list): A list of login flows. Each flow is an ordered list of
                           strings representing auth-types. At least one full
                           flow must be completed in order for auth to be successful.
+
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
+
             clientip (str): The IP address of the client.
+
         Returns:
-            A tuple of (authed, dict, dict, session_id) where authed is true if
-            the client has successfully completed an auth flow. If it is true
-            the first dict contains the authenticated credentials of each stage.
+            defer.Deferred[dict, dict, str]: a deferred tuple of
+                (creds, params, session_id).
 
-            If authed is false, the first dictionary is the server response to
-            the login request and should be passed back to the client.
+                'creds' contains the authenticated credentials of each stage.
 
-            In either case, the second dict contains the parameters for this
-            request (which may have been given only in a previous call).
+                'params' contains the parameters for this request (which may
+                have been given only in a previous call).
 
-            session_id is the ID of this session, either passed in by the client
-            or assigned by the call to check_auth
+                'session_id' is the ID of this session, either passed in by the
+                client or assigned by this call
+
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                all the stages in any of the permitted flows.
         """
 
         authdict = None
@@ -142,11 +155,8 @@ class AuthHandler(BaseHandler):
             clientdict = session['clientdict']
 
         if not authdict:
-            defer.returnValue(
-                (
-                    False, self._auth_dict_for_flows(flows, session),
-                    clientdict, session['id']
-                )
+            raise InteractiveAuthIncompleteError(
+                self._auth_dict_for_flows(flows, session),
             )
 
         if 'creds' not in session:
@@ -190,12 +200,14 @@ class AuthHandler(BaseHandler):
                     "Auth completed with creds: %r. Client dict has keys: %r",
                     creds, clientdict.keys()
                 )
-                defer.returnValue((True, creds, clientdict, session['id']))
+                defer.returnValue((creds, clientdict, session['id']))
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
         ret.update(errordict)
-        defer.returnValue((False, ret, clientdict, session['id']))
+        raise InteractiveAuthIncompleteError(
+            ret,
+        )
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 1f5bc24cc3..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
 
 """This module contains base REST classes for constructing client v1 servlets.
 """
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
 import re
 
-import logging
+from twisted.internet import defer
 
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
 
 logger = logging.getLogger(__name__)
 
@@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
         filter_json['room']['timeline']["limit"] = min(
             filter_json['room']['timeline']['limit'],
             filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+    """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+    Takes a on_POST method which returns a deferred (errcode, body) response
+    and adds exception handling to turn a InteractiveAuthIncompleteError into
+    a 401 response.
+
+    Normal usage is:
+
+    @interactive_auth_handler
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        # ...
+        yield self.auth_handler.check_auth
+            """
+    def wrapped(*args, **kwargs):
+        res = defer.maybeDeferred(orig, *args, **kwargs)
+        res.addErrback(_catch_incomplete_interactive_auth)
+        return res
+    return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+    """helper for interactive_auth_handler
+
+    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+    Args:
+        f (failure.Failure):
+    """
+    f.trap(InteractiveAuthIncompleteError)
+    return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c26ce63bcf..0d59a93222 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
 )
 from synapse.util.async import run_on_reactor
 from synapse.util.msisdn import phone_number_to_msisdn
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
@@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet):
         self.datastore = self.hs.get_datastore()
         self._set_password_handler = hs.get_set_password_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
 
         body = parse_json_object_from_request(request)
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
+        result, params, _ = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
             [LoginType.EMAIL_IDENTITY],
             [LoginType.MSISDN],
         ], body, self.hs.get_ip_from_request(request))
 
-        if not authed:
-            defer.returnValue((401, result))
-
         user_id = None
         requester = None
 
@@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
@@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet):
             )
             defer.returnValue((200, {}))
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
+        result, params, _ = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
         ], body, self.hs.get_ip_from_request(request))
 
-        if not authed:
-            defer.returnValue((401, result))
-
         if LoginType.PASSWORD in result:
             user_id = result[LoginType.PASSWORD]
             # if using password, they should also be logged in
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 5321e5abbb..909f9c087b 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 
 from synapse.api import constants, errors
 from synapse.http import servlet
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
@@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
         self.device_handler = hs.get_device_handler()
         self.auth_handler = hs.get_auth_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         try:
@@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
                 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
             )
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
+        result, params, _ = yield self.auth_handler.check_auth([
             [constants.LoginType.PASSWORD],
         ], body, self.hs.get_ip_from_request(request))
 
-        if not authed:
-            defer.returnValue((401, result))
-
         requester = yield self.auth.get_user_by_req(request)
         yield self.device_handler.delete_devices(
             requester.user.to_string(),
@@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet):
         )
         defer.returnValue((200, device))
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_DELETE(self, request, device_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet):
             else:
                 raise
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
+        result, params, _ = yield self.auth_handler.check_auth([
             [constants.LoginType.PASSWORD],
         ], body, self.hs.get_ip_from_request(request))
 
-        if not authed:
-            defer.returnValue((401, result))
-
         # check that the UI auth matched the access token
         user_id = result[constants.LoginType.PASSWORD]
         if user_id != requester.user.to_string():
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 9e2f7308ce..e9d88a8895 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -27,7 +27,7 @@ from synapse.http.servlet import (
 )
 from synapse.util.msisdn import phone_number_to_msisdn
 
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 import logging
 import hmac
@@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
@@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
                 ])
 
-        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+        auth_result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )
 
-        if not authed:
-            defer.returnValue((401, auth_result))
-            return
-
         if registered_user_id is not None:
             logger.info(
                 "Already registered user ID %r for this session",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 821c735528..096f771bea 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,5 +1,7 @@
+from twisted.python import failure
+
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
 from twisted.internet import defer
 from mock import Mock
 from tests import unittest
@@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             side_effect=lambda x: self.appservice)
         )
 
-        self.auth_result = (False, None, None, None)
+        self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
         self.auth_handler = Mock(
             check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
             get_session_data=Mock(return_value=None)
@@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
         self.request.args = {
             "access_token": "i_am_an_app_service"
         }
+
         self.request_data = json.dumps({
             "username": "kermit"
         })
@@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "device_id": device_id,
         })
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (True, None, {
+        self.auth_result = (None, {
             "username": "kermit",
             "password": "monkey"
         }, None)
@@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
             "password": "monkey"
         })
         self.registration_handler.check_username = Mock(return_value=True)
-        self.auth_result = (True, None, {
+        self.auth_result = (None, {
             "username": "kermit",
             "password": "monkey"
         }, None)