From 617501dd2a0562f4bf7edf8bc7a4e8aeb16b2254 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 11:35:56 +0100 Subject: Move token generation to auth handler I prefer the auth handler to worry about all auth, and register to call into it as needed, than to smatter auth logic between the two. --- synapse/handlers/auth.py | 29 ++++++++++++++--- synapse/handlers/register.py | 26 +++++---------- tests/handlers/test_auth.py | 70 +++++++++++++++++++++++++++++++++++++++++ tests/handlers/test_register.py | 70 ----------------------------------------- 4 files changed, 101 insertions(+), 94 deletions(-) create mode 100644 tests/handlers/test_auth.py delete mode 100644 tests/handlers/test_register.py diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index be2baeaece..0bf917efdd 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError import logging import bcrypt +import pymacaroons import simplejson import synapse.util.stringutils as stringutils @@ -284,12 +285,9 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ yield self._check_password(user_id, password) - - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_token(user_id) logger.info("Logging in user %s", user_id) - yield self.store.add_access_token_to_user(user_id, access_token) - defer.returnValue(access_token) + token = yield self.issue_access_token(user_id) + defer.returnValue(token) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -304,6 +302,27 @@ class AuthHandler(BaseHandler): logger.warn("Failed password login for user %s", user_id) raise LoginError(403, "", errcode=Codes.FORBIDDEN) + @defer.inlineCallbacks + def issue_access_token(self, user_id): + reg_handler = self.hs.get_handlers().registration_handler + access_token = reg_handler.generate_access_token(user_id) + yield self.store.add_access_token_to_user(user_id, access_token) + defer.returnValue(access_token) + + def generate_access_token(self, user_id): + macaroon = pymacaroons.Macaroon( + location = self.hs.config.server_name, + identifier = "key", + key = self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon.add_first_party_caveat("type = access") + now = self.hs.get_clock().time_msec() + expiry = now + (60 * 60 * 1000) + macaroon.add_first_party_caveat("time < %d" % (expiry,)) + + return macaroon.serialize() + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c391c1bdf5..3d1b6531c2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient import bcrypt import logging -import pymacaroons import urllib logger = logging.getLogger(__name__) @@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler): 400, "Invalid user localpart for this application service.", errcode=Codes.EXCLUSIVE ) - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, token=token, @@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler): user_id = user.to_string() yield self.check_user_id_is_valid(user_id) - token = self.generate_token(user_id) + token = self.auth_handler().generate_access_token(user_id) try: yield self.store.register( user_id=user_id, @@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - def generate_token(self, user_id): - macaroon = pymacaroons.Macaroon( - location = self.hs.config.server_name, - identifier = "key", - key = self.hs.config.macaroon_secret_key) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) - macaroon.add_first_party_caveat("time < %d" % (expiry,)) - - return macaroon.serialize() - def _generate_user_id(self): return "-" + stringutils.random_string(18) @@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler): } ) defer.returnValue(data) + + def auth_handler(self): + return self.hs.get_handlers().auth_handler diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py new file mode 100644 index 0000000000..978e4d0d2e --- /dev/null +++ b/tests/handlers/test_auth.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pymacaroons + +from mock import Mock, NonCallableMock +from synapse.handlers.auth import AuthHandler +from tests import unittest +from tests.utils import setup_test_homeserver +from twisted.internet import defer + + +class AuthHandlers(object): + def __init__(self, hs): + self.auth_handler = AuthHandler(hs) + + +class AuthTestCase(unittest.TestCase): + @defer.inlineCallbacks + def setUp(self): + self.hs = yield setup_test_homeserver(handlers=None) + self.hs.handlers = AuthHandlers(self.hs) + + def test_token_is_a_macaroon(self): + self.hs.config.macaroon_secret_key = "this key is a huge secret" + + token = self.hs.handlers.auth_handler.generate_access_token("some_user") + # Check that we can parse the thing with pymacaroons + macaroon = pymacaroons.Macaroon.deserialize(token) + # The most basic of sanity checks + if "some_user" not in macaroon.inspect(): + self.fail("some_user was not in %s" % macaroon.inspect()) + + def test_macaroon_caveats(self): + self.hs.config.macaroon_secret_key = "this key is a massive secret" + self.hs.clock.now = 5000 + + token = self.hs.handlers.auth_handler.generate_access_token("a_user") + macaroon = pymacaroons.Macaroon.deserialize(token) + + def verify_gen(caveat): + return caveat == "gen = 1" + + def verify_user(caveat): + return caveat == "user_id = a_user" + + def verify_type(caveat): + return caveat == "type = access" + + def verify_expiry(caveat): + return caveat == "time < 8600000" + + v = pymacaroons.Verifier() + v.satisfy_general(verify_gen) + v.satisfy_general(verify_user) + v.satisfy_general(verify_type) + v.satisfy_general(verify_expiry) + v.verify(macaroon, self.hs.config.macaroon_secret_key) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py deleted file mode 100644 index 91cc90242f..0000000000 --- a/tests/handlers/test_register.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pymacaroons - -from mock import Mock, NonCallableMock -from synapse.handlers.register import RegistrationHandler -from tests import unittest -from tests.utils import setup_test_homeserver -from twisted.internet import defer - - -class RegisterHandlers(object): - def __init__(self, hs): - self.registration_handler = RegistrationHandler(hs) - - -class RegisterTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(handlers=None) - self.hs.handlers = RegisterHandlers(self.hs) - - def test_token_is_a_macaroon(self): - self.hs.config.macaroon_secret_key = "this key is a huge secret" - - token = self.hs.handlers.registration_handler.generate_token("some_user") - # Check that we can parse the thing with pymacaroons - macaroon = pymacaroons.Macaroon.deserialize(token) - # The most basic of sanity checks - if "some_user" not in macaroon.inspect(): - self.fail("some_user was not in %s" % macaroon.inspect()) - - def test_macaroon_caveats(self): - self.hs.config.macaroon_secret_key = "this key is a massive secret" - self.hs.clock.now = 5000 - - token = self.hs.handlers.registration_handler.generate_token("a_user") - macaroon = pymacaroons.Macaroon.deserialize(token) - - def verify_gen(caveat): - return caveat == "gen = 1" - - def verify_user(caveat): - return caveat == "user_id = a_user" - - def verify_type(caveat): - return caveat == "type = access" - - def verify_expiry(caveat): - return caveat == "time < 8600000" - - v = pymacaroons.Verifier() - v.satisfy_general(verify_gen) - v.satisfy_general(verify_user) - v.satisfy_general(verify_type) - v.satisfy_general(verify_expiry) - v.verify(macaroon, self.hs.config.macaroon_secret_key) -- cgit 1.4.1 From 13a6517d89c0619a938321640f331571eba0edc9 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 16:01:29 +0100 Subject: s/by_token/by_access_token/g We're about to have two kinds of token, access and refresh --- synapse/api/auth.py | 6 +++--- synapse/storage/registration.py | 6 +++--- tests/api/test_auth.py | 16 ++++++++-------- tests/rest/client/v1/test_presence.py | 8 ++++---- tests/rest/client/v1/test_rooms.py | 28 ++++++++++++++-------------- tests/rest/client/v1/test_typing.py | 4 ++-- tests/rest/client/v1/utils.py | 2 +- tests/rest/client/v2_alpha/__init__.py | 4 ++-- tests/storage/test_registration.py | 4 ++-- tests/utils.py | 2 +- 10 files changed, 40 insertions(+), 40 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1e3b0fbfb7..3d9237ccc3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -361,7 +361,7 @@ class Auth(object): except KeyError: pass # normal users won't have the user_id query parameter set. - user_info = yield self.get_user_by_token(access_token) + user_info = yield self.get_user_by_access_token(access_token) user = user_info["user"] device_id = user_info["device_id"] token_id = user_info["token_id"] @@ -390,7 +390,7 @@ class Auth(object): ) @defer.inlineCallbacks - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): """ Get a registered user's ID. Args: @@ -401,7 +401,7 @@ class Auth(object): Raises: AuthError if no user by that token exists or the token is invalid. """ - ret = yield self.store.get_user_by_token(token) + ret = yield self.store.get_user_by_access_token(token) if not ret: raise AuthError( self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bf803f2c6e..0e404afb7c 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -132,10 +132,10 @@ class RegistrationStore(SQLBaseStore): user_id ) for r in rows: - self.get_user_by_token.invalidate((r,)) + self.get_user_by_access_token.invalidate((r,)) @cached() - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): """Get a user from the given access token. Args: @@ -147,7 +147,7 @@ class RegistrationStore(SQLBaseStore): StoreError if no user was found. """ return self.runInteraction( - "get_user_by_token", + "get_user_by_access_token", self._query_for_auth, token ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4f83db5e84..3343c635cc 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase): "token_id": "ditto", "admin": False } - self.store.get_user_by_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_user_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase): "token_id": "ditto", "admin": False } - self.store.get_user_by_token = Mock(return_value=user_info) + self.store.get_user_by_access_token = Mock(return_value=user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = Mock(return_value=[""]) @@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_missing_token(self): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.requestHeaders.getRawHeaders = Mock(return_value=[""]) @@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] @@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase): app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_token = Mock(return_value=None) + self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) request.args["access_token"] = [self.test_token] diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 089a71568c..0b78a82a66 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase): return defer.succeed([]) self.datastore.get_presence_list = get_presence_list - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), "admin": False, @@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token room_member_handler = hs.handlers.room_member_handler = Mock( spec=[ @@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase): ) self.datastore.has_presence_state = has_presence_state - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(myid), "admin": False, @@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase): ] ) - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token presence.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index c83348acf9..2e55cc08a1 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase): self.auth_user_id = self.user_id - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, @@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 7d8b1c2683..dc8bbaaf0e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase): hs.get_handlers().federation_handler = Mock() - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.auth_user_id), "admin": False, @@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase): "token_id": 1, } - hs.get_v1auth().get_user_by_token = _get_user_by_token + hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 579441fb4a..c472d53043 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase): self.mock_resource = None self.auth_user_id = None - def mock_get_user_by_token(self, token=None): + def mock_get_user_by_access_token(self, token=None): return self.auth_user_id @defer.inlineCallbacks diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index de5a917e6a..15568b36cd 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase): resource_for_federation=self.mock_resource, ) - def _get_user_by_token(token=None): + def _get_user_by_access_token(token=None): return { "user": UserID.from_string(self.USER_ID), "admin": False, "device_id": None, "token_id": 1, } - hs.get_auth().get_user_by_token = _get_user_by_token + hs.get_auth().get_user_by_access_token = _get_user_by_access_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 2702291178..7a24cf898a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase): (yield self.store.get_user_by_id(self.user_id)) ) - result = yield self.store.get_user_by_token(self.tokens[0]) + result = yield self.store.get_user_by_access_token(self.tokens[0]) self.assertDictContainsSubset( { @@ -64,7 +64,7 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) - result = yield self.store.get_user_by_token(self.tokens[1]) + result = yield self.store.get_user_by_access_token(self.tokens[1]) self.assertDictContainsSubset( { diff --git a/tests/utils.py b/tests/utils.py index 80be70b74f..d0fba2252d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -277,7 +277,7 @@ class MemoryDataStore(object): raise StoreError(400, "User in use.") self.tokens_to_users[token] = user_id - def get_user_by_token(self, token): + def get_user_by_access_token(self, token): try: return { "name": self.tokens_to_users[token], -- cgit 1.4.1 From cecbd636e94f4e900ef6d246b62698ff1c8ee352 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 16:21:35 +0100 Subject: /tokenrefresh POST endpoint This allows refresh tokens to be exchanged for (access_token, refresh_token). It also starts issuing them on login, though no clients currently interpret them. --- synapse/handlers/auth.py | 35 ++++++++++-- synapse/rest/client/v1/login.py | 6 ++- synapse/rest/client/v2_alpha/__init__.py | 2 + synapse/rest/client/v2_alpha/tokenrefresh.py | 56 +++++++++++++++++++ synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 1 + synapse/storage/registration.py | 62 ++++++++++++++++++++++ synapse/storage/schema/delta/23/refresh_tokens.sql | 21 ++++++++ tests/storage/test_registration.py | 55 +++++++++++++++++++ 9 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 synapse/rest/client/v2_alpha/tokenrefresh.py create mode 100644 synapse/storage/schema/delta/23/refresh_tokens.sql diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0bf917efdd..65bd8189dc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -279,15 +279,18 @@ class AuthHandler(BaseHandler): user_id (str): User ID password (str): Password Returns: - The access token for the user's session. + A tuple of: + The access token for the user's session. + The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ yield self._check_password(user_id, password) logger.info("Logging in user %s", user_id) - token = yield self.issue_access_token(user_id) - defer.returnValue(token) + access_token = yield self.issue_access_token(user_id) + refresh_token = yield self.issue_refresh_token(user_id) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -304,11 +307,16 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id): - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_access_token(user_id) + access_token = self.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token) defer.returnValue(access_token) + @defer.inlineCallbacks + def issue_refresh_token(self, user_id): + refresh_token = self.generate_refresh_token(user_id) + yield self.store.add_refresh_token_to_user(user_id, refresh_token) + defer.returnValue(refresh_token) + def generate_access_token(self, user_id): macaroon = pymacaroons.Macaroon( location = self.hs.config.server_name, @@ -323,6 +331,23 @@ class AuthHandler(BaseHandler): return macaroon.serialize() + def generate_refresh_token(self, user_id): + m = self._generate_base_macaroon(user_id) + m.add_first_party_caveat("type = refresh") + # Important to add a nonce, because otherwise every refresh token for a + # user will be the same. + m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16)) + return m.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location = self.hs.config.server_name, + identifier = "key", + key = self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 694072693d..b963a38618 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -78,13 +78,15 @@ class LoginRestServlet(ClientV1RestServlet): login_submission["user"] = UserID.create( login_submission["user"], self.hs.hostname).to_string() - token = yield self.handlers.auth_handler.login_with_password( + auth_handler = self.handlers.auth_handler + access_token, refresh_token = yield auth_handler.login_with_password( user_id=login_submission["user"], password=login_submission["password"]) result = { "user_id": login_submission["user"], # may have changed - "access_token": token, + "access_token": access_token, + "refresh_token": refresh_token, "home_server": self.hs.hostname, } diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 33f961e898..5831ff0e62 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -21,6 +21,7 @@ from . import ( auth, receipts, keys, + tokenrefresh, ) from synapse.http.server import JsonResource @@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource): auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) + tokenrefresh.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py new file mode 100644 index 0000000000..901e777983 --- /dev/null +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern, parse_json_dict_from_request + + +class TokenRefreshRestServlet(RestServlet): + """ + Exchanges refresh tokens for a pair of an access token and a new refresh + token. + """ + PATTERN = client_v2_pattern("/tokenrefresh") + + def __init__(self, hs): + super(TokenRefreshRestServlet, self).__init__() + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_dict_from_request(request) + try: + old_refresh_token = body["refresh_token"] + auth_handler = self.hs.get_handlers().auth_handler + (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token) + new_access_token = yield auth_handler.issue_access_token(user_id) + defer.returnValue((200, { + "access_token": new_access_token, + "refresh_token": new_refresh_token, + })) + except KeyError: + raise SynapseError(400, "Missing required key 'refresh_token'.") + except StoreError: + raise AuthError(403, "Did not recognize refresh token") + + +def register_servlets(hs, http_server): + TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f154b1c8ae..53673b3bf5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 22 +SCHEMA_VERSION = 23 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1444767a52..ce71389f02 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -181,6 +181,7 @@ class SQLBaseStore(object): self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) + self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0e404afb7c..f632306688 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -50,6 +50,28 @@ class RegistrationStore(SQLBaseStore): desc="add_access_token_to_user", ) + @defer.inlineCallbacks + def add_refresh_token_to_user(self, user_id, token): + """Adds a refresh token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new refresh token to add. + Raises: + StoreError if there was a problem adding this. + """ + next_id = yield self._refresh_tokens_id_gen.get_next() + + yield self._simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "token": token + }, + desc="add_refresh_token_to_user", + ) + @defer.inlineCallbacks def register(self, user_id, token, password_hash): """Attempts to register an account. @@ -152,6 +174,46 @@ class RegistrationStore(SQLBaseStore): token ) + def exchange_refresh_token(self, refresh_token, token_generator): + """Exchange a refresh token for a new access token and refresh token. + + Doing so invalidates the old refresh token - refresh tokens are single + use. + + Args: + token (str): The refresh token of a user. + token_generator (fn: str -> str): Function which, when given a + user ID, returns a unique refresh token for that user. This + function must never return the same value twice. + Returns: + tuple of (user_id, refresh_token) + Raises: + StoreError if no user was found with that refresh token. + """ + return self.runInteraction( + "exchange_refresh_token", + self._exchange_refresh_token, + refresh_token, + token_generator + ) + + def _exchange_refresh_token(self, txn, old_token, token_generator): + sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + txn.execute(sql, (old_token,)) + rows = self.cursor_to_dict(txn) + if not rows: + raise StoreError(403, "Did not recognize refresh token") + user_id = rows[0]["user_id"] + + # TODO(danielwh): Maybe perform a validation on the macaroon that + # macaroon.user_id == user_id. + + new_token = token_generator(user_id) + sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" + txn.execute(sql, (new_token, old_token,)) + + return user_id, new_token + @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql new file mode 100644 index 0000000000..46839e016c --- /dev/null +++ b/synapse/storage/schema/delta/23/refresh_tokens.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS refresh_tokens( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (token) +); diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7a24cf898a..a4f929796a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,7 +17,9 @@ from tests import unittest from twisted.internet import defer +from synapse.api.errors import StoreError from synapse.storage.registration import RegistrationStore +from synapse.util import stringutils from tests.utils import setup_test_homeserver @@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() + self.db_pool = hs.get_db_pool() self.store = RegistrationStore(hs) @@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertTrue("token_id" in result) + @defer.inlineCallbacks + def test_exchange_refresh_token_valid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, last_token,)) + + (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( + last_token, generator.generate) + self.assertEqual(uid, found_user_id) + + rows = yield self.db_pool.runQuery( + "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) + self.assertEqual([(refresh_token,)], rows) + # We issued token 1, then exchanged it for token 2 + expected_refresh_token = u"%s-%d" % (uid, 2,) + self.assertEqual(expected_refresh_token, refresh_token) + + @defer.inlineCallbacks + def test_exchange_refresh_token_none(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + @defer.inlineCallbacks + def test_exchange_refresh_token_invalid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + wrong_token = "%s-wrong" % (last_token,) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, wrong_token,)) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + +class TokenGenerator: + def __init__(self): + self._last_issued_token = 0 + + def generate(self, user_id): + self._last_issued_token += 1 + return u"%s-%d" % (user_id, self._last_issued_token,) -- cgit 1.4.1 From ea570ffaebf59219c06d5d3d65400b1f1b1384bd Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 17:22:41 +0100 Subject: Fix flake8 warnings --- synapse/handlers/auth.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 17465d2af6..1b0971e13d 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -336,14 +336,16 @@ class AuthHandler(BaseHandler): m.add_first_party_caveat("type = refresh") # Important to add a nonce, because otherwise every refresh token for a # user will be the same. - m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16)) + m.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) return m.serialize() def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( - location = self.hs.config.server_name, - identifier = "key", - key = self.hs.config.macaroon_secret_key) + location=self.hs.config.server_name, + identifier="key", + key=self.hs.config.macaroon_secret_key) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon -- cgit 1.4.1 From 8c74bd896010c6011a63bc7147e39e0078df7dcb Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 17:26:52 +0100 Subject: Fix indentation --- synapse/handlers/auth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1b0971e13d..e043363176 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -321,7 +321,8 @@ class AuthHandler(BaseHandler): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key) + key=self.hs.config.macaroon_secret_key + ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("type = access") -- cgit 1.4.1 From c7788685b061dc1fbbecc07e472570f99f36dca3 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Thu, 20 Aug 2015 17:43:12 +0100 Subject: Fix bad merge --- synapse/handlers/auth.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e043363176..c983d444e8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -318,18 +318,11 @@ class AuthHandler(BaseHandler): defer.returnValue(refresh_token) def generate_access_token(self, user_id): - macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, - identifier="key", - key=self.hs.config.macaroon_secret_key - ) - macaroon.add_first_party_caveat("gen = 1") - macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") now = self.hs.get_clock().time_msec() expiry = now + (60 * 60 * 1000) macaroon.add_first_party_caveat("time < %d" % (expiry,)) - return macaroon.serialize() def generate_refresh_token(self, user_id): -- cgit 1.4.1