diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr
+from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={},
)
+ # If a registration token was used, decrement the pending counter
+ # before deleting the session.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ )
+
+ # Get the tokens used and how much pending needs to be decremented by.
+ token_counts: Dict[str, int] = {}
+ for r in rows:
+ # If registration was successfully completed, the result of the
+ # registration token stage for that session will be True.
+ # If a token was used to authenticate, but registration was
+ # never completed, the result will be the token used.
+ token = db_to_json(r["result"])
+ if isinstance(token, str):
+ token_counts[token] = token_counts.get(token, 0) + 1
+
+ # Update the `pending` counters.
+ if len(token_counts) > 0:
+ token_rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ )
+ for token_row in token_rows:
+ token = token_row["token"]
+ new_pending = token_row["pending"] - token_counts[token]
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": new_pending},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
|