summary refs log tree commit diff
diff options
context:
space:
mode:
authorDaniel Wagner-Hall <dawagner@gmail.com>2015-08-20 17:44:46 +0100
committerDaniel Wagner-Hall <dawagner@gmail.com>2015-08-20 17:44:46 +0100
commitb1e35eabf240c6231e5dcc6e8022f537e236829d (patch)
treef01b04d22e4d487bd61c96768154e5e4b6498b10
parentMerge pull request #229 from matrix-org/auth (diff)
parentFix bad merge (diff)
downloadsynapse-b1e35eabf240c6231e5dcc6e8022f537e236829d.tar.xz
Merge pull request #240 from matrix-org/refresh
/tokenrefresh POST endpoint
-rw-r--r--synapse/api/auth.py6
-rw-r--r--synapse/handlers/auth.py52
-rw-r--r--synapse/handlers/register.py26
-rw-r--r--synapse/rest/client/v1/login.py8
-rw-r--r--synapse/rest/client/v2_alpha/__init__.py2
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py56
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/_base.py1
-rw-r--r--synapse/storage/registration.py68
-rw-r--r--synapse/storage/schema/delta/23/refresh_tokens.sql21
-rw-r--r--tests/api/test_auth.py16
-rw-r--r--tests/handlers/test_auth.py (renamed from tests/handlers/test_register.py)14
-rw-r--r--tests/rest/client/v1/test_presence.py8
-rw-r--r--tests/rest/client/v1/test_rooms.py28
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v1/utils.py2
-rw-r--r--tests/rest/client/v2_alpha/__init__.py4
-rw-r--r--tests/storage/test_registration.py59
-rw-r--r--tests/utils.py2
19 files changed, 303 insertions, 76 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1e3b0fbfb7..3d9237ccc3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -361,7 +361,7 @@ class Auth(object):
             except KeyError:
                 pass  # normal users won't have the user_id query parameter set.
 
-            user_info = yield self.get_user_by_token(access_token)
+            user_info = yield self.get_user_by_access_token(access_token)
             user = user_info["user"]
             device_id = user_info["device_id"]
             token_id = user_info["token_id"]
@@ -390,7 +390,7 @@ class Auth(object):
             )
 
     @defer.inlineCallbacks
-    def get_user_by_token(self, token):
+    def get_user_by_access_token(self, token):
         """ Get a registered user's ID.
 
         Args:
@@ -401,7 +401,7 @@ class Auth(object):
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
-        ret = yield self.store.get_user_by_token(token)
+        ret = yield self.store.get_user_by_access_token(token)
         if not ret:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index ff2c66f442..c983d444e8 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
 
 import logging
 import bcrypt
+import pymacaroons
 import simplejson
 
 import synapse.util.stringutils as stringutils
@@ -278,18 +279,18 @@ class AuthHandler(BaseHandler):
             user_id (str): User ID
             password (str): Password
         Returns:
-            The access token for the user's session.
+            A tuple of:
+              The access token for the user's session.
+              The refresh token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
         """
         yield self._check_password(user_id, password)
-
-        reg_handler = self.hs.get_handlers().registration_handler
-        access_token = reg_handler.generate_token(user_id)
         logger.info("Logging in user %s", user_id)
-        yield self.store.add_access_token_to_user(user_id, access_token)
-        defer.returnValue(access_token)
+        access_token = yield self.issue_access_token(user_id)
+        refresh_token = yield self.issue_refresh_token(user_id)
+        defer.returnValue((access_token, refresh_token))
 
     @defer.inlineCallbacks
     def _check_password(self, user_id, password):
@@ -305,6 +306,45 @@ class AuthHandler(BaseHandler):
             raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
+    def issue_access_token(self, user_id):
+        access_token = self.generate_access_token(user_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
+        defer.returnValue(access_token)
+
+    @defer.inlineCallbacks
+    def issue_refresh_token(self, user_id):
+        refresh_token = self.generate_refresh_token(user_id)
+        yield self.store.add_refresh_token_to_user(user_id, refresh_token)
+        defer.returnValue(refresh_token)
+
+    def generate_access_token(self, user_id):
+        macaroon = self._generate_base_macaroon(user_id)
+        macaroon.add_first_party_caveat("type = access")
+        now = self.hs.get_clock().time_msec()
+        expiry = now + (60 * 60 * 1000)
+        macaroon.add_first_party_caveat("time < %d" % (expiry,))
+        return macaroon.serialize()
+
+    def generate_refresh_token(self, user_id):
+        m = self._generate_base_macaroon(user_id)
+        m.add_first_party_caveat("type = refresh")
+        # Important to add a nonce, because otherwise every refresh token for a
+        # user will be the same.
+        m.add_first_party_caveat("nonce = %s" % (
+            stringutils.random_string_with_symbols(16),
+        ))
+        return m.serialize()
+
+    def _generate_base_macaroon(self, user_id):
+        macaroon = pymacaroons.Macaroon(
+            location=self.hs.config.server_name,
+            identifier="key",
+            key=self.hs.config.macaroon_secret_key)
+        macaroon.add_first_party_caveat("gen = 1")
+        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
+        return macaroon
+
+    @defer.inlineCallbacks
     def set_password(self, user_id, newpassword):
         password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1adc3eebbb..3d1b6531c2 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient
 
 import bcrypt
 import logging
-import pymacaroons
 import urllib
 
 logger = logging.getLogger(__name__)
@@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler):
             user = UserID(localpart, self.hs.hostname)
             user_id = user.to_string()
 
-            token = self.generate_token(user_id)
+            token = self.auth_handler().generate_access_token(user_id)
             yield self.store.register(
                 user_id=user_id,
                 token=token,
@@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler):
                     user_id = user.to_string()
                     yield self.check_user_id_is_valid(user_id)
 
-                    token = self.generate_token(user_id)
+                    token = self.auth_handler().generate_access_token(user_id)
                     yield self.store.register(
                         user_id=user_id,
                         token=token,
@@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler):
                 400, "Invalid user localpart for this application service.",
                 errcode=Codes.EXCLUSIVE
             )
-        token = self.generate_token(user_id)
+        token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
             token=token,
@@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler):
         user_id = user.to_string()
 
         yield self.check_user_id_is_valid(user_id)
-        token = self.generate_token(user_id)
+        token = self.auth_handler().generate_access_token(user_id)
         try:
             yield self.store.register(
                 user_id=user_id,
@@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE
                 )
 
-    def generate_token(self, user_id):
-        macaroon = pymacaroons.Macaroon(
-            location=self.hs.config.server_name,
-            identifier="key",
-            key=self.hs.config.macaroon_secret_key)
-        macaroon.add_first_party_caveat("gen = 1")
-        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        macaroon.add_first_party_caveat("type = access")
-        now = self.hs.get_clock().time_msec()
-        expiry = now + (60 * 60 * 1000)
-        macaroon.add_first_party_caveat("time < %d" % (expiry,))
-
-        return macaroon.serialize()
-
     def _generate_user_id(self):
         return "-" + stringutils.random_string(18)
 
@@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler):
             }
         )
         defer.returnValue(data)
+
+    def auth_handler(self):
+        return self.hs.get_handlers().auth_handler
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0d5eafd0fa..67323a16bb 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -85,13 +85,15 @@ class LoginRestServlet(ClientV1RestServlet):
             user_id = UserID.create(
                 user_id, self.hs.hostname).to_string()
 
-        token = yield self.handlers.auth_handler.login_with_password(
+        auth_handler = self.handlers.auth_handler
+        access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
             password=login_submission["password"])
 
         result = {
-            "user_id": user_id,  # may have changed
-            "access_token": token,
+            "user_id": login_submission["user"],  # may have changed
+            "access_token": access_token,
+            "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
         }
 
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 33f961e898..5831ff0e62 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -21,6 +21,7 @@ from . import (
     auth,
     receipts,
     keys,
+    tokenrefresh,
 )
 
 from synapse.http.server import JsonResource
@@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource):
         auth.register_servlets(hs, client_resource)
         receipts.register_servlets(hs, client_resource)
         keys.register_servlets(hs, client_resource)
+        tokenrefresh.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
new file mode 100644
index 0000000000..901e777983
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
+from synapse.http.servlet import RestServlet
+
+from ._base import client_v2_pattern, parse_json_dict_from_request
+
+
+class TokenRefreshRestServlet(RestServlet):
+    """
+    Exchanges refresh tokens for a pair of an access token and a new refresh
+    token.
+    """
+    PATTERN = client_v2_pattern("/tokenrefresh")
+
+    def __init__(self, hs):
+        super(TokenRefreshRestServlet, self).__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_dict_from_request(request)
+        try:
+            old_refresh_token = body["refresh_token"]
+            auth_handler = self.hs.get_handlers().auth_handler
+            (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
+                old_refresh_token, auth_handler.generate_refresh_token)
+            new_access_token = yield auth_handler.issue_access_token(user_id)
+            defer.returnValue((200, {
+                "access_token": new_access_token,
+                "refresh_token": new_refresh_token,
+            }))
+        except KeyError:
+            raise SynapseError(400, "Missing required key 'refresh_token'.")
+        except StoreError:
+            raise AuthError(403, "Did not recognize refresh token")
+
+
+def register_servlets(hs, http_server):
+    TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f154b1c8ae..53673b3bf5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 22
+SCHEMA_VERSION = 23
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1444767a52..ce71389f02 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -181,6 +181,7 @@ class SQLBaseStore(object):
         self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
         self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
         self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index bf803f2c6e..f632306688 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -51,6 +51,28 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
+    def add_refresh_token_to_user(self, user_id, token):
+        """Adds a refresh token for the given user.
+
+        Args:
+            user_id (str): The user ID.
+            token (str): The new refresh token to add.
+        Raises:
+            StoreError if there was a problem adding this.
+        """
+        next_id = yield self._refresh_tokens_id_gen.get_next()
+
+        yield self._simple_insert(
+            "refresh_tokens",
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "token": token
+            },
+            desc="add_refresh_token_to_user",
+        )
+
+    @defer.inlineCallbacks
     def register(self, user_id, token, password_hash):
         """Attempts to register an account.
 
@@ -132,10 +154,10 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate((r,))
+            self.get_user_by_access_token.invalidate((r,))
 
     @cached()
-    def get_user_by_token(self, token):
+    def get_user_by_access_token(self, token):
         """Get a user from the given access token.
 
         Args:
@@ -147,11 +169,51 @@ class RegistrationStore(SQLBaseStore):
             StoreError if no user was found.
         """
         return self.runInteraction(
-            "get_user_by_token",
+            "get_user_by_access_token",
             self._query_for_auth,
             token
         )
 
+    def exchange_refresh_token(self, refresh_token, token_generator):
+        """Exchange a refresh token for a new access token and refresh token.
+
+        Doing so invalidates the old refresh token - refresh tokens are single
+        use.
+
+        Args:
+            token (str): The refresh token of a user.
+            token_generator (fn: str -> str): Function which, when given a
+                user ID, returns a unique refresh token for that user. This
+                function must never return the same value twice.
+        Returns:
+            tuple of (user_id, refresh_token)
+        Raises:
+            StoreError if no user was found with that refresh token.
+        """
+        return self.runInteraction(
+            "exchange_refresh_token",
+            self._exchange_refresh_token,
+            refresh_token,
+            token_generator
+        )
+
+    def _exchange_refresh_token(self, txn, old_token, token_generator):
+        sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+        txn.execute(sql, (old_token,))
+        rows = self.cursor_to_dict(txn)
+        if not rows:
+            raise StoreError(403, "Did not recognize refresh token")
+        user_id = rows[0]["user_id"]
+
+        # TODO(danielwh): Maybe perform a validation on the macaroon that
+        # macaroon.user_id == user_id.
+
+        new_token = token_generator(user_id)
+        sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
+        txn.execute(sql, (new_token, old_token,))
+
+        return user_id, new_token
+
     @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/23/refresh_tokens.sql
new file mode 100644
index 0000000000..46839e016c
--- /dev/null
+++ b/synapse/storage/schema/delta/23/refresh_tokens.sql
@@ -0,0 +1,21 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS refresh_tokens(
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    token TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    UNIQUE (token)
+);
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 4f83db5e84..3343c635cc 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase):
             "token_id": "ditto",
             "admin": False
         }
-        self.store.get_user_by_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase):
 
     def test_get_user_by_req_user_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
             "token_id": "ditto",
             "admin": False
         }
-        self.store.get_user_by_token = Mock(return_value=user_info)
+        self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_valid_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
 
     def test_get_user_by_req_appservice_bad_token(self):
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_missing_token(self):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
@@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_token = Mock(return_value=None)
+        self.store.get_user_by_access_token = Mock(return_value=None)
 
         request = Mock(args={})
         request.args["access_token"] = [self.test_token]
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_auth.py
index 91cc90242f..978e4d0d2e 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_auth.py
@@ -16,27 +16,27 @@
 import pymacaroons
 
 from mock import Mock, NonCallableMock
-from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.auth import AuthHandler
 from tests import unittest
 from tests.utils import setup_test_homeserver
 from twisted.internet import defer
 
 
-class RegisterHandlers(object):
+class AuthHandlers(object):
     def __init__(self, hs):
-        self.registration_handler = RegistrationHandler(hs)
+        self.auth_handler = AuthHandler(hs)
 
 
-class RegisterTestCase(unittest.TestCase):
+class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.hs = yield setup_test_homeserver(handlers=None)
-        self.hs.handlers = RegisterHandlers(self.hs)
+        self.hs.handlers = AuthHandlers(self.hs)
 
     def test_token_is_a_macaroon(self):
         self.hs.config.macaroon_secret_key = "this key is a huge secret"
 
-        token = self.hs.handlers.registration_handler.generate_token("some_user")
+        token = self.hs.handlers.auth_handler.generate_access_token("some_user")
         # Check that we can parse the thing with pymacaroons
         macaroon = pymacaroons.Macaroon.deserialize(token)
         # The most basic of sanity checks
@@ -47,7 +47,7 @@ class RegisterTestCase(unittest.TestCase):
         self.hs.config.macaroon_secret_key = "this key is a massive secret"
         self.hs.clock.now = 5000
 
-        token = self.hs.handlers.registration_handler.generate_token("a_user")
+        token = self.hs.handlers.auth_handler.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
         def verify_gen(caveat):
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 089a71568c..0b78a82a66 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase):
             return defer.succeed([])
         self.datastore.get_presence_list = get_presence_list
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(myid),
                 "admin": False,
@@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         room_member_handler = hs.handlers.room_member_handler = Mock(
             spec=[
@@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase):
             )
         self.datastore.has_presence_state = has_presence_state
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(myid),
                 "admin": False,
@@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase):
             ]
         )
 
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         presence.register_servlets(hs, self.mock_resource)
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index c83348acf9..2e55cc08a1 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase):
 
         self.auth_user_id = self.user_id
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
@@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
@@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 7d8b1c2683..dc8bbaaf0e 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase):
 
         hs.get_handlers().federation_handler = Mock()
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.auth_user_id),
                 "admin": False,
@@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase):
                 "token_id": 1,
             }
 
-        hs.get_v1auth().get_user_by_token = _get_user_by_token
+        hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
 
         def _insert_client_ip(*args, **kwargs):
             return defer.succeed(None)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 579441fb4a..c472d53043 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase):
         self.mock_resource = None
         self.auth_user_id = None
 
-    def mock_get_user_by_token(self, token=None):
+    def mock_get_user_by_access_token(self, token=None):
         return self.auth_user_id
 
     @defer.inlineCallbacks
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
index de5a917e6a..15568b36cd 100644
--- a/tests/rest/client/v2_alpha/__init__.py
+++ b/tests/rest/client/v2_alpha/__init__.py
@@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
             resource_for_federation=self.mock_resource,
         )
 
-        def _get_user_by_token(token=None):
+        def _get_user_by_access_token(token=None):
             return {
                 "user": UserID.from_string(self.USER_ID),
                 "admin": False,
                 "device_id": None,
                 "token_id": 1,
             }
-        hs.get_auth().get_user_by_token = _get_user_by_token
+        hs.get_auth().get_user_by_access_token = _get_user_by_access_token
 
         for r in self.TO_REGISTER:
             r.register_servlets(hs, self.mock_resource)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 2702291178..a4f929796a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,7 +17,9 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.api.errors import StoreError
 from synapse.storage.registration import RegistrationStore
+from synapse.util import stringutils
 
 from tests.utils import setup_test_homeserver
 
@@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver()
+        self.db_pool = hs.get_db_pool()
 
         self.store = RegistrationStore(hs)
 
@@ -46,7 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             (yield self.store.get_user_by_id(self.user_id))
         )
 
-        result = yield self.store.get_user_by_token(self.tokens[0])
+        result = yield self.store.get_user_by_access_token(self.tokens[0])
 
         self.assertDictContainsSubset(
             {
@@ -64,7 +67,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
         yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
 
-        result = yield self.store.get_user_by_token(self.tokens[1])
+        result = yield self.store.get_user_by_access_token(self.tokens[1])
 
         self.assertDictContainsSubset(
             {
@@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         self.assertTrue("token_id" in result)
 
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_valid(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+
+        self.db_pool.runQuery(
+            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+            (uid, last_token,))
+
+        (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
+            last_token, generator.generate)
+        self.assertEqual(uid, found_user_id)
+
+        rows = yield self.db_pool.runQuery(
+            "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
+        self.assertEqual([(refresh_token,)], rows)
+        # We issued token 1, then exchanged it for token 2
+        expected_refresh_token = u"%s-%d" % (uid, 2,)
+        self.assertEqual(expected_refresh_token, refresh_token)
+
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_none(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_invalid(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+        wrong_token = "%s-wrong" % (last_token,)
+
+        self.db_pool.runQuery(
+            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+            (uid, wrong_token,))
+
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+
+class TokenGenerator:
+    def __init__(self):
+        self._last_issued_token = 0
+
+    def generate(self, user_id):
+        self._last_issued_token += 1
+        return u"%s-%d" % (user_id, self._last_issued_token,)
diff --git a/tests/utils.py b/tests/utils.py
index 80be70b74f..d0fba2252d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -277,7 +277,7 @@ class MemoryDataStore(object):
             raise StoreError(400, "User in use.")
         self.tokens_to_users[token] = user_id
 
-    def get_user_by_token(self, token):
+    def get_user_by_access_token(self, token):
         try:
             return {
                 "name": self.tokens_to_users[token],