diff options
-rw-r--r-- | synapse/handlers/auth.py | 35 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 6 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/__init__.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/tokenrefresh.py | 56 | ||||
-rw-r--r-- | synapse/storage/__init__.py | 2 | ||||
-rw-r--r-- | synapse/storage/_base.py | 1 | ||||
-rw-r--r-- | synapse/storage/registration.py | 62 | ||||
-rw-r--r-- | synapse/storage/schema/delta/23/refresh_tokens.sql | 21 | ||||
-rw-r--r-- | tests/storage/test_registration.py | 55 |
9 files changed, 232 insertions, 8 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0bf917efdd..65bd8189dc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -279,15 +279,18 @@ class AuthHandler(BaseHandler): user_id (str): User ID password (str): Password Returns: - The access token for the user's session. + A tuple of: + The access token for the user's session. + The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ yield self._check_password(user_id, password) logger.info("Logging in user %s", user_id) - token = yield self.issue_access_token(user_id) - defer.returnValue(token) + access_token = yield self.issue_access_token(user_id) + refresh_token = yield self.issue_refresh_token(user_id) + defer.returnValue((access_token, refresh_token)) @defer.inlineCallbacks def _check_password(self, user_id, password): @@ -304,11 +307,16 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def issue_access_token(self, user_id): - reg_handler = self.hs.get_handlers().registration_handler - access_token = reg_handler.generate_access_token(user_id) + access_token = self.generate_access_token(user_id) yield self.store.add_access_token_to_user(user_id, access_token) defer.returnValue(access_token) + @defer.inlineCallbacks + def issue_refresh_token(self, user_id): + refresh_token = self.generate_refresh_token(user_id) + yield self.store.add_refresh_token_to_user(user_id, refresh_token) + defer.returnValue(refresh_token) + def generate_access_token(self, user_id): macaroon = pymacaroons.Macaroon( location = self.hs.config.server_name, @@ -323,6 +331,23 @@ class AuthHandler(BaseHandler): return macaroon.serialize() + def generate_refresh_token(self, user_id): + m = self._generate_base_macaroon(user_id) + m.add_first_party_caveat("type = refresh") + # Important to add a nonce, because otherwise every refresh token for a + # user will be the same. + m.add_first_party_caveat("nonce = %s" % stringutils.random_string_with_symbols(16)) + return m.serialize() + + def _generate_base_macaroon(self, user_id): + macaroon = pymacaroons.Macaroon( + location = self.hs.config.server_name, + identifier = "key", + key = self.hs.config.macaroon_secret_key) + macaroon.add_first_party_caveat("gen = 1") + macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) + return macaroon + @defer.inlineCallbacks def set_password(self, user_id, newpassword): password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 694072693d..b963a38618 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -78,13 +78,15 @@ class LoginRestServlet(ClientV1RestServlet): login_submission["user"] = UserID.create( login_submission["user"], self.hs.hostname).to_string() - token = yield self.handlers.auth_handler.login_with_password( + auth_handler = self.handlers.auth_handler + access_token, refresh_token = yield auth_handler.login_with_password( user_id=login_submission["user"], password=login_submission["password"]) result = { "user_id": login_submission["user"], # may have changed - "access_token": token, + "access_token": access_token, + "refresh_token": refresh_token, "home_server": self.hs.hostname, } diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index 33f961e898..5831ff0e62 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -21,6 +21,7 @@ from . import ( auth, receipts, keys, + tokenrefresh, ) from synapse.http.server import JsonResource @@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource): auth.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) + tokenrefresh.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py new file mode 100644 index 0000000000..901e777983 --- /dev/null +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.http.servlet import RestServlet + +from ._base import client_v2_pattern, parse_json_dict_from_request + + +class TokenRefreshRestServlet(RestServlet): + """ + Exchanges refresh tokens for a pair of an access token and a new refresh + token. + """ + PATTERN = client_v2_pattern("/tokenrefresh") + + def __init__(self, hs): + super(TokenRefreshRestServlet, self).__init__() + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_dict_from_request(request) + try: + old_refresh_token = body["refresh_token"] + auth_handler = self.hs.get_handlers().auth_handler + (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token) + new_access_token = yield auth_handler.issue_access_token(user_id) + defer.returnValue((200, { + "access_token": new_access_token, + "refresh_token": new_refresh_token, + })) + except KeyError: + raise SynapseError(400, "Missing required key 'refresh_token'.") + except StoreError: + raise AuthError(403, "Did not recognize refresh token") + + +def register_servlets(hs, http_server): + TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index f154b1c8ae..53673b3bf5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 22 +SCHEMA_VERSION = 23 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 1444767a52..ce71389f02 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -181,6 +181,7 @@ class SQLBaseStore(object): self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) + self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 0e404afb7c..f632306688 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -51,6 +51,28 @@ class RegistrationStore(SQLBaseStore): ) @defer.inlineCallbacks + def add_refresh_token_to_user(self, user_id, token): + """Adds a refresh token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new refresh token to add. + Raises: + StoreError if there was a problem adding this. + """ + next_id = yield self._refresh_tokens_id_gen.get_next() + + yield self._simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "token": token + }, + desc="add_refresh_token_to_user", + ) + + @defer.inlineCallbacks def register(self, user_id, token, password_hash): """Attempts to register an account. @@ -152,6 +174,46 @@ class RegistrationStore(SQLBaseStore): token ) + def exchange_refresh_token(self, refresh_token, token_generator): + """Exchange a refresh token for a new access token and refresh token. + + Doing so invalidates the old refresh token - refresh tokens are single + use. + + Args: + token (str): The refresh token of a user. + token_generator (fn: str -> str): Function which, when given a + user ID, returns a unique refresh token for that user. This + function must never return the same value twice. + Returns: + tuple of (user_id, refresh_token) + Raises: + StoreError if no user was found with that refresh token. + """ + return self.runInteraction( + "exchange_refresh_token", + self._exchange_refresh_token, + refresh_token, + token_generator + ) + + def _exchange_refresh_token(self, txn, old_token, token_generator): + sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + txn.execute(sql, (old_token,)) + rows = self.cursor_to_dict(txn) + if not rows: + raise StoreError(403, "Did not recognize refresh token") + user_id = rows[0]["user_id"] + + # TODO(danielwh): Maybe perform a validation on the macaroon that + # macaroon.user_id == user_id. + + new_token = token_generator(user_id) + sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" + txn.execute(sql, (new_token, old_token,)) + + return user_id, new_token + @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql new file mode 100644 index 0000000000..46839e016c --- /dev/null +++ b/synapse/storage/schema/delta/23/refresh_tokens.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS refresh_tokens( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token TEXT NOT NULL, + user_id TEXT NOT NULL, + UNIQUE (token) +); diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7a24cf898a..a4f929796a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,7 +17,9 @@ from tests import unittest from twisted.internet import defer +from synapse.api.errors import StoreError from synapse.storage.registration import RegistrationStore +from synapse.util import stringutils from tests.utils import setup_test_homeserver @@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() + self.db_pool = hs.get_db_pool() self.store = RegistrationStore(hs) @@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertTrue("token_id" in result) + @defer.inlineCallbacks + def test_exchange_refresh_token_valid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, last_token,)) + + (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( + last_token, generator.generate) + self.assertEqual(uid, found_user_id) + + rows = yield self.db_pool.runQuery( + "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) + self.assertEqual([(refresh_token,)], rows) + # We issued token 1, then exchanged it for token 2 + expected_refresh_token = u"%s-%d" % (uid, 2,) + self.assertEqual(expected_refresh_token, refresh_token) + + @defer.inlineCallbacks + def test_exchange_refresh_token_none(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + @defer.inlineCallbacks + def test_exchange_refresh_token_invalid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + wrong_token = "%s-wrong" % (last_token,) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, wrong_token,)) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + +class TokenGenerator: + def __init__(self): + self._last_issued_token = 0 + + def generate(self, user_id): + self._last_issued_token += 1 + return u"%s-%d" % (user_id, self._last_issued_token,) |