summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/device.py6
-rw-r--r--synapse/storage/registration.py58
-rw-r--r--tests/storage/test_registration.py34
3 files changed, 83 insertions, 15 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9e65d85e6d..eaead50800 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler):
             else:
                 raise
 
-        yield self.store.user_delete_access_tokens(user_id,
-                                                   device_id=device_id)
+        yield self.store.user_delete_access_tokens(
+            user_id, device_id=device_id,
+            delete_refresh_tokens=True,
+        )
 
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 935e82bf7a..d9555e073a 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
     @defer.inlineCallbacks
     def user_delete_access_tokens(self, user_id, except_token_ids=[],
-                                  device_id=None):
-        def f(txn):
-            sql = "SELECT token FROM access_tokens WHERE user_id = ?"
+                                  device_id=None,
+                                  delete_refresh_tokens=False):
+        """
+        Invalidate access/refresh tokens belonging to a user
+
+        Args:
+            user_id (str):  ID of user the tokens belong to
+            except_token_ids (list[str]): list of access_tokens which should
+                *not* be deleted
+            device_id (str|None):  ID of device the tokens are associated with.
+                If None, tokens associated with any device (or no device) will
+                be deleted
+            delete_refresh_tokens (bool):  True to delete refresh tokens as
+                well as access tokens.
+        Returns:
+            defer.Deferred:
+        """
+        def f(txn, table, except_tokens, call_after_delete):
+            sql = "SELECT token FROM %s WHERE user_id = ?" % table
             clauses = [user_id]
 
             if device_id is not None:
                 sql += " AND device_id = ?"
                 clauses.append(device_id)
 
-            if except_token_ids:
+            if except_tokens:
                 sql += " AND id NOT IN (%s)" % (
-                    ",".join(["?" for _ in except_token_ids]),
+                    ",".join(["?" for _ in except_tokens]),
                 )
-                clauses += except_token_ids
+                clauses += except_tokens
 
             txn.execute(sql, clauses)
 
@@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             n = 100
             chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
             for chunk in chunks:
-                for row in chunk:
-                    txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+                if call_after_delete:
+                    for row in chunk:
+                        txn.call_after(call_after_delete, (row[0],))
 
                 txn.execute(
-                    "DELETE FROM access_tokens WHERE token in (%s)" % (
+                    "DELETE FROM %s WHERE token in (%s)" % (
+                        table,
                         ",".join(["?" for _ in chunk]),
                     ), [r[0] for r in chunk]
                 )
 
-        yield self.runInteraction("user_delete_access_tokens", f)
+        # delete refresh tokens first, to stop new access tokens being
+        # allocated while our backs are turned
+        if delete_refresh_tokens:
+            yield self.runInteraction(
+                "user_delete_access_tokens", f,
+                table="refresh_tokens",
+                except_tokens=[],
+                call_after_delete=None,
+            )
+
+        yield self.runInteraction(
+            "user_delete_access_tokens", f,
+            table="access_tokens",
+            except_tokens=except_token_ids,
+            call_after_delete=self.get_user_by_access_token.invalidate,
+        )
 
     def delete_access_token(self, access_token):
         def f(txn):
@@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         Args:
             token (str): The access token of a user.
         Returns:
-            dict: Including the name (user_id) and the ID of their access token.
-        Raises:
-            StoreError if no user was found.
+            defer.Deferred: None, if the token did not match, ootherwise dict
+                including the keys `name`, `is_guest`, `device_id`, `token_id`.
         """
         return self.runInteraction(
             "get_user_by_access_token",
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index b03ca303a2..f7d74dea8e 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
         with self.assertRaises(StoreError):
             yield self.store.exchange_refresh_token(last_token, generator.generate)
 
+    @defer.inlineCallbacks
+    def test_user_delete_access_tokens(self):
+        # add some tokens
+        generator = TokenGenerator()
+        refresh_token = generator.generate(self.user_id)
+        yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
+        yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
+                                                  self.device_id)
+        yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
+                                                   self.device_id)
+
+        # now delete some
+        yield self.store.user_delete_access_tokens(
+            self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
+
+        # check they were deleted
+        user = yield self.store.get_user_by_access_token(self.tokens[1])
+        self.assertIsNone(user, "access token was not deleted by device_id")
+        with self.assertRaises(StoreError):
+            yield self.store.exchange_refresh_token(refresh_token,
+                                                    generator.generate)
+
+        # check the one not associated with the device was not deleted
+        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        self.assertEqual(self.user_id, user["name"])
+
+        # now delete the rest
+        yield self.store.user_delete_access_tokens(
+            self.user_id, delete_refresh_tokens=True)
+
+        user = yield self.store.get_user_by_access_token(self.tokens[0])
+        self.assertIsNone(user,
+                          "access token was not deleted without device_id")
+
 
 class TokenGenerator:
     def __init__(self):