summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_registration.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 7a24cf898a..a4f929796a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -17,7 +17,9 @@
 from tests import unittest
 from twisted.internet import defer
 
+from synapse.api.errors import StoreError
 from synapse.storage.registration import RegistrationStore
+from synapse.util import stringutils
 
 from tests.utils import setup_test_homeserver
 
@@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         hs = yield setup_test_homeserver()
+        self.db_pool = hs.get_db_pool()
 
         self.store = RegistrationStore(hs)
 
@@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
 
         self.assertTrue("token_id" in result)
 
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_valid(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+
+        self.db_pool.runQuery(
+            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+            (uid, last_token,))
+
+        (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
+            last_token, generator.generate)
+        self.assertEqual(uid, found_user_id)
+
+        rows = yield self.db_pool.runQuery(
+            "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
+        self.assertEqual([(refresh_token,)], rows)
+        # We issued token 1, then exchanged it for token 2
+        expected_refresh_token = u"%s-%d" % (uid, 2,)
+        self.assertEqual(expected_refresh_token, refresh_token)
+
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_none(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+    @defer.inlineCallbacks
+    def test_exchange_refresh_token_invalid(self):
+        uid = stringutils.random_string(32)
+        generator = TokenGenerator()
+        last_token = generator.generate(uid)
+        wrong_token = "%s-wrong" % (last_token,)
+
+        self.db_pool.runQuery(
+            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
+            (uid, wrong_token,))
+
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(last_token, generator.generate)
+
+
+class TokenGenerator:
+    def __init__(self):
+        self._last_issued_token = 0
+
+    def generate(self, user_id):
+        self._last_issued_token += 1
+        return u"%s-%d" % (user_id, self._last_issued_token,)