diff options
Diffstat (limited to 'tests/storage')
-rw-r--r-- | tests/storage/test_registration.py | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7a24cf898a..a4f929796a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -17,7 +17,9 @@ from tests import unittest from twisted.internet import defer +from synapse.api.errors import StoreError from synapse.storage.registration import RegistrationStore +from synapse.util import stringutils from tests.utils import setup_test_homeserver @@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver() + self.db_pool = hs.get_db_pool() self.store = RegistrationStore(hs) @@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase): self.assertTrue("token_id" in result) + @defer.inlineCallbacks + def test_exchange_refresh_token_valid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, last_token,)) + + (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( + last_token, generator.generate) + self.assertEqual(uid, found_user_id) + + rows = yield self.db_pool.runQuery( + "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) + self.assertEqual([(refresh_token,)], rows) + # We issued token 1, then exchanged it for token 2 + expected_refresh_token = u"%s-%d" % (uid, 2,) + self.assertEqual(expected_refresh_token, refresh_token) + + @defer.inlineCallbacks + def test_exchange_refresh_token_none(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + @defer.inlineCallbacks + def test_exchange_refresh_token_invalid(self): + uid = stringutils.random_string(32) + generator = TokenGenerator() + last_token = generator.generate(uid) + wrong_token = "%s-wrong" % (last_token,) + + self.db_pool.runQuery( + "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", + (uid, wrong_token,)) + + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(last_token, generator.generate) + + +class TokenGenerator: + def __init__(self): + self._last_issued_token = 0 + + def generate(self, user_id): + self._last_issued_token += 1 + return u"%s-%d" % (user_id, self._last_issued_token,) |