summary refs log tree commit diff
path: root/tests/storage/test_registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_registration.py')
-rw-r--r--tests/storage/test_registration.py55
1 files changed, 48 insertions, 7 deletions
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index b8384c98d8..f7d74dea8e 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
             "BcDeFgHiJkLmNoPqRsTuVwXyZa"
         ]
         self.pwhash = "{xx1}123456789"
+        self.device_id = "akgjhdjklgshg"
 
     @defer.inlineCallbacks
     def test_register(self):
@@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_add_tokens(self):
         yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
-        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
+        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+                                                  self.device_id)
 
         result = yield self.store.get_user_by_access_token(self.tokens[1])
 
         self.assertDictContainsSubset(
             {
                 "name": self.user_id,
+                "device_id": self.device_id,
             },
             result
         )
@@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_exchange_refresh_token_valid(self):
         uid = stringutils.random_string(32)
+        device_id = stringutils.random_string(16)
         generator = TokenGenerator()
         last_token = generator.generate(uid)
 
         self.db_pool.runQuery(
-            "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
-            (uid, last_token,))
+            "INSERT INTO refresh_tokens(user_id, token, device_id) "
+            "VALUES(?,?,?)",
+            (uid, last_token, device_id))
 
-        (found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
-            last_token, generator.generate)
+        (found_user_id, refresh_token, device_id) = \
+            yield self.store.exchange_refresh_token(last_token,
+                                                    generator.generate)
         self.assertEqual(uid, found_user_id)
 
         rows = yield self.db_pool.runQuery(
-            "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
-        self.assertEqual([(refresh_token,)], rows)
+            "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?",
+            (uid, ))
+        self.assertEqual([(refresh_token, device_id)], rows)
         # We issued token 1, then exchanged it for token 2
         expected_refresh_token = u"%s-%d" % (uid, 2,)
         self.assertEqual(expected_refresh_token, refresh_token)
@@ -121,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
         with self.assertRaises(StoreError):
             yield self.store.exchange_refresh_token(last_token, generator.generate)
 
+    @defer.inlineCallbacks
+    def test_user_delete_access_tokens(self):
+        # add some tokens
+        generator = TokenGenerator()
+        refresh_token = generator.generate(self.user_id)
+        yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+                                                  self.device_id)
+        yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
+                                                   self.device_id)
+
+        # now delete some
+        yield self.store.user_delete_access_tokens(
+            self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+
+        # check they were deleted
+        user = yield self.store.get_user_by_access_token(self.tokens[1])
+        self.assertIsNone(user, "access token was not deleted by device_id")
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(refresh_token,
+                                                    generator.generate)
+
+        # check the one not associated with the device was not deleted
+        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        self.assertEqual(self.user_id, user["name"])
+
+        # now delete the rest
+        yield self.store.user_delete_access_tokens(
+            self.user_id, delete_refresh_tokens=True)
+
+        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        self.assertIsNone(user,
+                          "access token was not deleted without device_id")
+
 
 class TokenGenerator:
     def __init__(self):