summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12056.bugfix1
-rw-r--r--synapse/storage/databases/main/registration.py18
-rw-r--r--synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql28
-rw-r--r--tests/rest/client/test_auth.py93
4 files changed, 136 insertions, 4 deletions
diff --git a/changelog.d/12056.bugfix b/changelog.d/12056.bugfix
new file mode 100644
index 0000000000..210e30c63f
--- /dev/null
+++ b/changelog.d/12056.bugfix
@@ -0,0 +1 @@
+Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens.
\ No newline at end of file
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 17110bb033..dc6665237a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1681,7 +1681,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 user_id=row[1],
                 device_id=row[2],
                 next_token_id=row[3],
-                has_next_refresh_token_been_refreshed=row[4],
+                # SQLite returns 0 or 1 for false/true, so convert to a bool.
+                has_next_refresh_token_been_refreshed=bool(row[4]),
                 # This column is nullable, ensure it's a boolean
                 has_next_access_token_been_used=(row[5] or False),
                 expiry_ts=row[6],
@@ -1697,12 +1698,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         Set the successor of a refresh token, removing the existing successor
         if any.
 
+        This also deletes the predecessor refresh and access tokens,
+        since they cannot be valid anymore.
+
         Args:
             token_id: ID of the refresh token to update.
             next_token_id: ID of its successor.
         """
 
-        def _replace_refresh_token_txn(txn) -> None:
+        def _replace_refresh_token_txn(txn: LoggingTransaction) -> None:
             # First check if there was an existing refresh token
             old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
                 txn,
@@ -1728,6 +1732,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                     {"id": old_next_token_id},
                 )
 
+            # Delete the previous refresh token, since we only want to keep the
+            # last 2 refresh tokens in the database.
+            # (The predecessor of the latest refresh token is still useful in
+            # case the refresh was interrupted and the client re-uses the old
+            # one.)
+            # This cascades to delete the associated access token.
+            self.db_pool.simple_delete_txn(
+                txn, "refresh_tokens", {"next_token_id": token_id}
+            )
+
         await self.db_pool.runInteraction(
             "replace_refresh_token", _replace_refresh_token_txn
         )
diff --git a/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql
new file mode 100644
index 0000000000..09305638ea
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/04_refresh_tokens_index_next_token_id.sql
@@ -0,0 +1,28 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- next_token_id is a foreign key reference, so previously required a table scan
+-- when a row in the referenced table was deleted.
+-- As it was self-referential and cascaded deletes, this led to O(t*n) time to
+-- delete a row, where t: number of rows in the table and n: number of rows in
+-- the ancestral 'chain' of access tokens.
+--
+-- This index is partial since we only require it for rows which reference
+-- another.
+-- Performance was tested to be the same regardless of whether the index was
+-- full or partial, but a partial index can be smaller.
+CREATE INDEX refresh_tokens_next_token_id
+    ON refresh_tokens(next_token_id)
+    WHERE next_token_id IS NOT NULL;
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 27cb856b0a..4a68d66573 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 from twisted.internet.defer import succeed
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.rest.client import account, auth, devices, login, register
+from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
@@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         auth.register_servlets,
         account.register_servlets,
         login.register_servlets,
+        logout.register_servlets,
         synapse.rest.admin.register_servlets_for_client_rest_resource,
         register.register_servlets,
     ]
@@ -984,3 +986,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(
             fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
         )
+
+    def test_many_token_refresh(self):
+        """
+        If a refresh is performed many times during a session, there shouldn't be
+        extra 'cruft' built up over time.
+
+        This test was written specifically to troubleshoot a case where logout
+        was very slow if a lot of refreshes had been performed for the session.
+        """
+
+        def _refresh(refresh_token: str) -> Tuple[str, str]:
+            """
+            Performs one refresh, returning the next refresh token and access token.
+            """
+            refresh_response = self.use_refresh_token(refresh_token)
+            self.assertEqual(
+                refresh_response.code, HTTPStatus.OK, refresh_response.result
+            )
+            return (
+                refresh_response.json_body["refresh_token"],
+                refresh_response.json_body["access_token"],
+            )
+
+        def _table_length(table_name: str) -> int:
+            """
+            Helper to get the size of a table, in rows.
+            For testing only; trivially vulnerable to SQL injection.
+            """
+
+            def _txn(txn: LoggingTransaction) -> int:
+                txn.execute(f"SELECT COUNT(1) FROM {table_name}")
+                row = txn.fetchone()
+                # Query is infallible
+                assert row is not None
+                return row[0]
+
+            return self.get_success(
+                self.hs.get_datastores().main.db_pool.runInteraction(
+                    "_table_length", _txn
+                )
+            )
+
+        # Before we log in, there are no access tokens.
+        self.assertEqual(_table_length("access_tokens"), 0)
+        self.assertEqual(_table_length("refresh_tokens"), 0)
+
+        body = {
+            "type": "m.login.password",
+            "user": "test",
+            "password": self.user_pass,
+            "refresh_token": True,
+        }
+        login_response = self.make_request(
+            "POST",
+            "/_matrix/client/v3/login",
+            body,
+        )
+        self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
+
+        access_token = login_response.json_body["access_token"]
+        refresh_token = login_response.json_body["refresh_token"]
+
+        # Now that we have logged in, there should be one access token and one
+        # refresh token
+        self.assertEqual(_table_length("access_tokens"), 1)
+        self.assertEqual(_table_length("refresh_tokens"), 1)
+
+        for _ in range(5):
+            refresh_token, access_token = _refresh(refresh_token)
+
+        # After 5 sequential refreshes, there should only be the latest two
+        # refresh/access token pairs.
+        # (The last one is preserved because it's in use!
+        # The one before that is preserved because it can still be used to
+        # replace the last token pair, in case of e.g. a network interruption.)
+        self.assertEqual(_table_length("access_tokens"), 2)
+        self.assertEqual(_table_length("refresh_tokens"), 2)
+
+        logout_response = self.make_request(
+            "POST", "/_matrix/client/v3/logout", {}, access_token=access_token
+        )
+        self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result)
+
+        # Now that we have logged in, there should be no access token
+        # and no refresh token
+        self.assertEqual(_table_length("access_tokens"), 0)
+        self.assertEqual(_table_length("refresh_tokens"), 0)