diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 27cb856b0a..4a68d66573 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
from twisted.internet.defer import succeed
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.rest.client import account, auth, devices, login, register
+from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID
from tests import unittest
@@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
auth.register_servlets,
account.register_servlets,
login.register_servlets,
+ logout.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
register.register_servlets,
]
@@ -984,3 +986,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertEqual(
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)
+
+ def test_many_token_refresh(self):
+ """
+ If a refresh is performed many times during a session, there shouldn't be
+ extra 'cruft' built up over time.
+
+ This test was written specifically to troubleshoot a case where logout
+ was very slow if a lot of refreshes had been performed for the session.
+ """
+
+ def _refresh(refresh_token: str) -> Tuple[str, str]:
+ """
+ Performs one refresh, returning the next refresh token and access token.
+ """
+ refresh_response = self.use_refresh_token(refresh_token)
+ self.assertEqual(
+ refresh_response.code, HTTPStatus.OK, refresh_response.result
+ )
+ return (
+ refresh_response.json_body["refresh_token"],
+ refresh_response.json_body["access_token"],
+ )
+
+ def _table_length(table_name: str) -> int:
+ """
+ Helper to get the size of a table, in rows.
+ For testing only; trivially vulnerable to SQL injection.
+ """
+
+ def _txn(txn: LoggingTransaction) -> int:
+ txn.execute(f"SELECT COUNT(1) FROM {table_name}")
+ row = txn.fetchone()
+ # Query is infallible
+ assert row is not None
+ return row[0]
+
+ return self.get_success(
+ self.hs.get_datastores().main.db_pool.runInteraction(
+ "_table_length", _txn
+ )
+ )
+
+ # Before we log in, there are no access tokens.
+ self.assertEqual(_table_length("access_tokens"), 0)
+ self.assertEqual(_table_length("refresh_tokens"), 0)
+
+ body = {
+ "type": "m.login.password",
+ "user": "test",
+ "password": self.user_pass,
+ "refresh_token": True,
+ }
+ login_response = self.make_request(
+ "POST",
+ "/_matrix/client/v3/login",
+ body,
+ )
+ self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
+
+ access_token = login_response.json_body["access_token"]
+ refresh_token = login_response.json_body["refresh_token"]
+
+ # Now that we have logged in, there should be one access token and one
+ # refresh token
+ self.assertEqual(_table_length("access_tokens"), 1)
+ self.assertEqual(_table_length("refresh_tokens"), 1)
+
+ for _ in range(5):
+ refresh_token, access_token = _refresh(refresh_token)
+
+ # After 5 sequential refreshes, there should only be the latest two
+ # refresh/access token pairs.
+ # (The last one is preserved because it's in use!
+ # The one before that is preserved because it can still be used to
+ # replace the last token pair, in case of e.g. a network interruption.)
+ self.assertEqual(_table_length("access_tokens"), 2)
+ self.assertEqual(_table_length("refresh_tokens"), 2)
+
+ logout_response = self.make_request(
+ "POST", "/_matrix/client/v3/logout", {}, access_token=access_token
+ )
+ self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result)
+
+ # Now that we have logged in, there should be no access token
+ # and no refresh token
+ self.assertEqual(_table_length("access_tokens"), 0)
+ self.assertEqual(_table_length("refresh_tokens"), 0)
|