diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index d0dfa959dc..79b35b3e7c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
pass
+class InteractiveAuthIncompleteError(Exception):
+ """An error raised when UI auth is not yet complete
+
+ (This indicates we should return a 401 with 'result' as the body)
+
+ Attributes:
+ result (dict): the server response to the request, which should be
+ passed back to the client
+ """
+ def __init__(self, result):
+ super(InteractiveAuthIncompleteError, self).__init__(
+ "Interactive auth not yet complete",
+ )
+ self.result = result
+
+
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2f30f183ce..28c80608a7 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -17,7 +17,10 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
+from synapse.api.errors import (
+ AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
+ SynapseError,
+)
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
@@ -95,26 +98,36 @@ class AuthHandler(BaseHandler):
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
+ If no auth flows have been completed successfully, raises an
+ InteractiveAuthIncompleteError. To handle this, you can use
+ synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+ decorator.
+
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
+
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
+
clientip (str): The IP address of the client.
+
Returns:
- A tuple of (authed, dict, dict, session_id) where authed is true if
- the client has successfully completed an auth flow. If it is true
- the first dict contains the authenticated credentials of each stage.
+ defer.Deferred[dict, dict, str]: a deferred tuple of
+ (creds, params, session_id).
- If authed is false, the first dictionary is the server response to
- the login request and should be passed back to the client.
+ 'creds' contains the authenticated credentials of each stage.
- In either case, the second dict contains the parameters for this
- request (which may have been given only in a previous call).
+ 'params' contains the parameters for this request (which may
+ have been given only in a previous call).
- session_id is the ID of this session, either passed in by the client
- or assigned by the call to check_auth
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call
+
+ Raises:
+ InteractiveAuthIncompleteError if the client has not yet completed
+ all the stages in any of the permitted flows.
"""
authdict = None
@@ -142,11 +155,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict']
if not authdict:
- defer.returnValue(
- (
- False, self._auth_dict_for_flows(flows, session),
- clientdict, session['id']
- )
+ raise InteractiveAuthIncompleteError(
+ self._auth_dict_for_flows(flows, session),
)
if 'creds' not in session:
@@ -190,12 +200,14 @@ class AuthHandler(BaseHandler):
"Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys()
)
- defer.returnValue((True, creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
ret.update(errordict)
- defer.returnValue((False, ret, clientdict, session['id']))
+ raise InteractiveAuthIncompleteError(
+ ret,
+ )
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 1f5bc24cc3..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
import re
-import logging
+from twisted.internet import defer
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__)
@@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+ """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+ Takes a on_POST method which returns a deferred (errcode, body) response
+ and adds exception handling to turn a InteractiveAuthIncompleteError into
+ a 401 response.
+
+ Normal usage is:
+
+ @interactive_auth_handler
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ # ...
+ yield self.auth_handler.check_auth
+ """
+ def wrapped(*args, **kwargs):
+ res = defer.maybeDeferred(orig, *args, **kwargs)
+ res.addErrback(_catch_incomplete_interactive_auth)
+ return res
+ return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+ """helper for interactive_auth_handler
+
+ Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+ Args:
+ f (failure.Failure):
+ """
+ f.trap(InteractiveAuthIncompleteError)
+ return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c26ce63bcf..0d59a93222 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -100,21 +100,19 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request)
- authed, result, params, _ = yield self.auth_handler.check_auth([
+ result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY],
[LoginType.MSISDN],
], body, self.hs.get_ip_from_request(request))
- if not authed:
- defer.returnValue((401, result))
-
user_id = None
requester = None
@@ -168,6 +166,7 @@ class DeactivateAccountRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
@@ -186,13 +185,10 @@ class DeactivateAccountRestServlet(RestServlet):
)
defer.returnValue((200, {}))
- authed, result, params, _ = yield self.auth_handler.check_auth([
+ result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
- if not authed:
- defer.returnValue((401, result))
-
if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 5321e5abbb..909f9c087b 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api import constants, errors
from synapse.http import servlet
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -60,6 +60,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
try:
@@ -77,13 +78,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
- authed, result, params, _ = yield self.auth_handler.check_auth([
+ result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
- if not authed:
- defer.returnValue((401, result))
-
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
@@ -115,6 +113,7 @@ class DeviceRestServlet(servlet.RestServlet):
)
defer.returnValue((200, device))
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
@@ -130,13 +129,10 @@ class DeviceRestServlet(servlet.RestServlet):
else:
raise
- authed, result, params, _ = yield self.auth_handler.check_auth([
+ result, params, _ = yield self.auth_handler.check_auth([
[constants.LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
- if not authed:
- defer.returnValue((401, result))
-
# check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string():
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 9e2f7308ce..e9d88a8895 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -27,7 +27,7 @@ from synapse.http.servlet import (
)
from synapse.util.msisdn import phone_number_to_msisdn
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
import logging
import hmac
@@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
@@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
- authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+ auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
- if not authed:
- defer.returnValue((401, auth_result))
- return
-
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 821c735528..096f771bea 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -1,5 +1,7 @@
+from twisted.python import failure
+
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
from twisted.internet import defer
from mock import Mock
from tests import unittest
@@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
side_effect=lambda x: self.appservice)
)
- self.auth_result = (False, None, None, None)
+ self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
@@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.request.args = {
"access_token": "i_am_an_app_service"
}
+
self.request_data = json.dumps({
"username": "kermit"
})
@@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (True, None, {
+ self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
@@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
- self.auth_result = (True, None, {
+ self.auth_result = (None, {
"username": "kermit",
"password": "monkey"
}, None)
|