diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 01b918e12e..00a644e8f7 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,6 +63,7 @@ from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
+from .session import SessionStore
from .signatures import SignatureStore
from .state import StateStore
from .stats import StatsStore
@@ -121,6 +122,7 @@ class DataStore(
ServerMetricsStore,
EventForwardExtremitiesStore,
LockStore,
+ SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 375463e4e9..9501f00f3b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore):
# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
- already_fetching: Dict[str, defer.Deferred] = {}
+ #
+ # Note: we might get the same `ObservableDeferred` back for multiple
+ # events we're already fetching, so we deduplicate the deferreds to
+ # avoid extraneous work (if we don't do this we can end up in a n^2 mode
+ # when we wait on the same Deferred N times, then try and merge the
+ # same dict into itself N times).
+ already_fetching_ids: Set[str] = set()
+ already_fetching_deferreds: Set[
+ ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ] = set()
for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
- already_fetching[event_id] = deferred.observe()
+ already_fetching_ids.add(event_id)
+ already_fetching_deferreds.add(deferred)
- missing_events_ids.difference_update(already_fetching)
+ missing_events_ids.difference_update(already_fetching_ids)
if missing_events_ids:
log_ctx = current_context()
@@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)
- if already_fetching:
+ if already_fetching_deferreds:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
- already_fetching.values(),
+ (d.observe() for d in already_fetching_deferreds),
consumeErrors=True,
)
).addErrback(unwrapFirstError)
for result in results:
- event_entry_map.update(result)
+ # We filter out events that we haven't asked for as we might get
+ # a *lot* of superfluous events back, and there is no point
+ # going through and inserting them all (which can take time).
+ event_entry_map.update(
+ (event_id, entry)
+ for event_id, entry in result.items()
+ if event_id in already_fetching_ids
+ )
if not allow_rejected:
event_entry_map = {
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 664c65dac5..bccff5e5b9 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
+ self._invalidate_get_event_cache(event_id)
logger.info("[purge] done")
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b48fe086d4..e47caa2125 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_stale_pushers,
)
+ self.db_pool.updates.register_background_update_handler(
+ "remove_deleted_email_pushers",
+ self._remove_deleted_email_pushers,
+ )
+
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -388,6 +393,73 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
+ async def _remove_deleted_email_pushers(
+ self, progress: dict, batch_size: int
+ ) -> int:
+ """A background update that deletes all pushers for deleted email addresses.
+
+ In previous versions of synapse, when users deleted their email address, it didn't
+ also delete all the pushers for that email address. This background update removes
+ those to prevent unwanted emails. This should only need to be run once (when users
+ upgrade to v1.42.0
+
+ Args:
+ progress: dict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ last_pusher = progress.get("last_pusher", 0)
+
+ def _delete_pushers(txn) -> int:
+
+ sql = """
+ SELECT p.id, p.user_name, p.app_id, p.pushkey
+ FROM pushers AS p
+ LEFT JOIN user_threepids AS t
+ ON t.user_id = p.user_name
+ AND t.medium = 'email'
+ AND t.address = p.pushkey
+ WHERE t.user_id is NULL
+ AND p.app_id = 'm.email'
+ AND p.id > ?
+ ORDER BY p.id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_pusher, batch_size))
+
+ last = None
+ num_deleted = 0
+ for row in txn:
+ last = row[0]
+ num_deleted += 1
+ self.db_pool.simple_delete_txn(
+ txn,
+ "pushers",
+ {"user_name": row[1], "app_id": row[2], "pushkey": row[3]},
+ )
+
+ if last is not None:
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "remove_deleted_email_pushers", {"last_pusher": last}
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_deleted_email_pushers", _delete_pushers
+ )
+
+ if number_deleted < batch_size:
+ await self.db_pool.updates._end_background_update(
+ "remove_deleted_email_pushers"
+ )
+
+ return number_deleted
+
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c67bea81c6..a6517962f6 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
return user_id
- def get_user_id_by_threepid_txn(self, txn, medium, address):
+ def get_user_id_by_threepid_txn(
+ self, txn, medium: str, address: str
+ ) -> Optional[str]:
"""Returns user id from threepid
Args:
txn (cursor):
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- str|None: user id or None if no user id/threepid mapping exists
+ user id, or None if no user id/threepid mapping exists
"""
ret = self.db_pool.simple_select_one_txn(
txn,
@@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return ret["user_id"]
return None
- async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ async def user_add_threepid(
+ self,
+ user_id: str,
+ medium: str,
+ address: str,
+ validated_at: int,
+ added_at: int,
+ ) -> None:
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- async def user_get_threepids(self, user_id):
+ async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
@@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"user_get_threepids",
)
- async def user_delete_threepid(self, user_id, medium, address) -> None:
+ async def user_delete_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
await self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
@@ -1157,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
+ async def registration_token_is_valid(self, token: str) -> bool:
+ """Checks if a token can be used to authenticate a registration.
+
+ Args:
+ token: The registration token to be checked
+ Returns:
+ True if the token is valid, False otherwise.
+ """
+ res = await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ )
+
+ # Check if the token exists
+ if res is None:
+ return False
+
+ # Check if the token has expired
+ now = self._clock.time_msec()
+ if res["expiry_time"] and res["expiry_time"] < now:
+ return False
+
+ # Check if the token has been used up
+ if (
+ res["uses_allowed"]
+ and res["pending"] + res["completed"] >= res["uses_allowed"]
+ ):
+ return False
+
+ # Otherwise, the token is valid
+ return True
+
+ async def set_registration_token_pending(self, token: str) -> None:
+ """Increment the pending registrations counter for a token.
+
+ Args:
+ token: The registration token pending use
+ """
+
+ def _set_registration_token_pending_txn(txn):
+ pending = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": pending + 1},
+ )
+
+ return await self.db_pool.runInteraction(
+ "set_registration_token_pending", _set_registration_token_pending_txn
+ )
+
+ async def use_registration_token(self, token: str) -> None:
+ """Complete a use of the given registration token.
+
+ The `pending` counter will be decremented, and the `completed`
+ counter will be incremented.
+
+ Args:
+ token: The registration token to be 'used'
+ """
+
+ def _use_registration_token_txn(txn):
+ # Normally, res is Optional[Dict[str, Any]].
+ # Override type because the return type is only optional if
+ # allow_none is True, and we don't want mypy throwing errors
+ # about None not being indexable.
+ res: Dict[str, Any] = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ ) # type: ignore
+
+ # Decrement pending and increment completed
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={
+ "completed": res["completed"] + 1,
+ "pending": res["pending"] - 1,
+ },
+ )
+
+ return await self.db_pool.runInteraction(
+ "use_registration_token", _use_registration_token_txn
+ )
+
+ async def get_registration_tokens(
+ self, valid: Optional[bool] = None
+ ) -> List[Dict[str, Any]]:
+ """List all registration tokens. Used by the admin API.
+
+ Args:
+ valid: If True, only valid tokens are returned.
+ If False, only invalid tokens are returned.
+ Default is None: return all tokens regardless of validity.
+
+ Returns:
+ A list of dicts, each containing details of a token.
+ """
+
+ def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+ if valid is None:
+ # Return all tokens regardless of validity
+ txn.execute("SELECT * FROM registration_tokens")
+
+ elif valid:
+ # Select valid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
+ "AND (expiry_time > ? OR expiry_time IS NULL)"
+ )
+ txn.execute(sql, [now])
+
+ else:
+ # Select invalid tokens only
+ sql = (
+ "SELECT * FROM registration_tokens WHERE "
+ "uses_allowed <= pending + completed OR expiry_time <= ?"
+ )
+ txn.execute(sql, [now])
+
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "select_registration_tokens",
+ select_registration_tokens_txn,
+ self._clock.time_msec(),
+ valid,
+ )
+
+ async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
+ """Get info about the given registration token. Used by the admin API.
+
+ Args:
+ token: The token to retrieve information about.
+
+ Returns:
+ A dict, or None if token doesn't exist.
+ """
+ return await self.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
+ allow_none=True,
+ desc="get_one_registration_token",
+ )
+
+ async def generate_registration_token(
+ self, length: int, chars: str
+ ) -> Optional[str]:
+ """Generate a random registration token. Used by the admin API.
+
+ Args:
+ length: The length of the token to generate.
+ chars: A string of the characters allowed in the generated token.
+
+ Returns:
+ The generated token.
+
+ Raises:
+ SynapseError if a unique registration token could still not be
+ generated after a few tries.
+ """
+ # Make a few attempts at generating a unique token of the required
+ # length before failing.
+ for _i in range(3):
+ # Generate token
+ token = "".join(random.choices(chars, k=length))
+
+ # Check if the token already exists
+ existing_token = await self.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="token",
+ allow_none=True,
+ desc="check_if_registration_token_exists",
+ )
+
+ if existing_token is None:
+ # The generated token doesn't exist yet, return it
+ return token
+
+ raise SynapseError(
+ 500,
+ "Unable to generate a unique registration token. Try again with a greater length",
+ Codes.UNKNOWN,
+ )
+
+ async def create_registration_token(
+ self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
+ ) -> bool:
+ """Create a new registration token. Used by the admin API.
+
+ Args:
+ token: The token to create.
+ uses_allowed: The number of times the token can be used to complete
+ a registration before it becomes invalid. A value of None indicates
+ unlimited uses.
+ expiry_time: The latest time the token is valid. Given as the
+ number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
+ None indicates that the token does not expire.
+
+ Returns:
+ Whether the row was inserted or not.
+ """
+
+ def _create_registration_token_txn(txn):
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["token"],
+ allow_none=True,
+ )
+
+ if row is not None:
+ # Token already exists
+ return False
+
+ self.db_pool.simple_insert_txn(
+ txn,
+ "registration_tokens",
+ values={
+ "token": token,
+ "uses_allowed": uses_allowed,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": expiry_time,
+ },
+ )
+
+ return True
+
+ return await self.db_pool.runInteraction(
+ "create_registration_token", _create_registration_token_txn
+ )
+
+ async def update_registration_token(
+ self, token: str, updatevalues: Dict[str, Optional[int]]
+ ) -> Optional[Dict[str, Any]]:
+ """Update a registration token. Used by the admin API.
+
+ Args:
+ token: The token to update.
+ updatevalues: A dict with the fields to update. E.g.:
+ `{"uses_allowed": 3}` to update just uses_allowed, or
+ `{"uses_allowed": 3, "expiry_time": None}` to update both.
+ This is passed straight to simple_update_one.
+
+ Returns:
+ A dict with all info about the token, or None if token doesn't exist.
+ """
+
+ def _update_registration_token_txn(txn):
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues=updatevalues,
+ )
+ except StoreError:
+ # Update failed because token does not exist
+ return None
+
+ # Get all info about the token so it can be sent in the response
+ return self.db_pool.simple_select_one_txn(
+ txn,
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=[
+ "token",
+ "uses_allowed",
+ "pending",
+ "completed",
+ "expiry_time",
+ ],
+ allow_none=True,
+ )
+
+ return await self.db_pool.runInteraction(
+ "update_registration_token", _update_registration_token_txn
+ )
+
+ async def delete_registration_token(self, token: str) -> bool:
+ """Delete a registration token. Used by the admin API.
+
+ Args:
+ token: The token to delete.
+
+ Returns:
+ Whether the token was successfully deleted or not.
+ """
+ try:
+ await self.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ desc="delete_registration_token",
+ )
+ except StoreError:
+ # Deletion failed because token does not exist
+ return False
+
+ return True
+
@cached()
async def mark_access_token_as_used(self, token_id: int) -> None:
"""
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e8157ba3d4..c58a4b8690 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
- async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
+ async def get_invited_rooms_for_local_user(
+ self, user_id: str
+ ) -> List[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -384,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
+ INNER JOIN rooms AS r USING (room_id)
WHERE
user_id = ?
AND %s
@@ -395,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
+ results = [RoomsForUser(*r) for r in txn]
return results
@@ -445,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
Returns the rooms the user is in currently, along with the stream
- ordering of the most recent join for that user and room.
+ ordering of the most recent join for that user and room, along with
+ the room version of the room.
"""
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
@@ -522,7 +526,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
_get_users_server_still_shares_room_with_txn,
)
- async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
+ async def get_rooms_for_user(
+ self, user_id: str, on_invalidate=None
+ ) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py
new file mode 100644
index 0000000000..172f27d109
--- /dev/null
+++ b/synapse/storage/databases/main/session.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class SessionStore(SQLBaseStore):
+ """
+ A store for generic session data.
+
+ Each type of session should provide a unique type (to separate sessions).
+
+ Sessions are automatically removed when they expire.
+ """
+
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ # Create a background job for culling expired sessions.
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)
+
+ async def create_session(
+ self, session_type: str, value: JsonDict, expiry_ms: int
+ ) -> str:
+ """
+ Creates a new pagination session for the room hierarchy endpoint.
+
+ Args:
+ session_type: The type for this session.
+ value: The value to store.
+ expiry_ms: How long before an item is evicted from the cache
+ in milliseconds. Default is 0, indicating items never get
+ evicted based on time.
+
+ Returns:
+ The newly created session ID.
+
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # autogen a session ID and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db_pool.simple_insert(
+ table="sessions",
+ values={
+ "session_id": session_id,
+ "session_type": session_type,
+ "value": json_encoder.encode(value),
+ "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms,
+ },
+ desc="create_session",
+ )
+
+ return session_id
+ except self.db_pool.engine.module.IntegrityError:
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_session(self, session_type: str, session_id: str) -> JsonDict:
+ """
+ Retrieve data stored with create_session
+
+ Args:
+ session_type: The type for this session.
+ session_id: The session ID returned from create_session.
+
+ Raises:
+ StoreError if the session cannot be found.
+ """
+
+ def _get_session(
+ txn: LoggingTransaction, session_type: str, session_id: str, ts: int
+ ) -> JsonDict:
+ # This includes the expiry time since items are only periodically
+ # deleted, not upon expiry.
+ select_sql = """
+ SELECT value FROM sessions WHERE
+ session_type = ? AND session_id = ? AND expiry_time_ms > ?
+ """
+ txn.execute(select_sql, [session_type, session_id, ts])
+ row = txn.fetchone()
+
+ if not row:
+ raise StoreError(404, "No session")
+
+ return db_to_json(row[0])
+
+ return await self.db_pool.runInteraction(
+ "get_session",
+ _get_session,
+ session_type,
+ session_id,
+ self._clock.time_msec(),
+ )
+
+ @wrap_as_background_process("delete_expired_sessions")
+ async def _delete_expired_sessions(self) -> None:
+ """Remove sessions with expiry dates that have passed."""
+
+ def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
+ sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "delete_expired_sessions",
+ _delete_expired_sessions_txn,
+ self._clock.time_msec(),
+ )
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 38bfdf5dad..4d6bbc94c7 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import attr
+from synapse.api.constants import LoginType
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
@@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore):
keyvalues={},
)
+ # If a registration token was used, decrement the pending counter
+ # before deleting the session.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN},
+ retcols=["result"],
+ )
+
+ # Get the tokens used and how much pending needs to be decremented by.
+ token_counts: Dict[str, int] = {}
+ for r in rows:
+ # If registration was successfully completed, the result of the
+ # registration token stage for that session will be True.
+ # If a token was used to authenticate, but registration was
+ # never completed, the result will be the token used.
+ token = db_to_json(r["result"])
+ if isinstance(token, str):
+ token_counts[token] = token_counts.get(token, 0) + 1
+
+ # Update the `pending` counters.
+ if len(token_counts) > 0:
+ token_rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="registration_tokens",
+ column="token",
+ iterable=list(token_counts.keys()),
+ keyvalues={},
+ retcols=["token", "pending"],
+ )
+ for token_row in token_rows:
+ token = token_row["token"]
+ new_pending = token_row["pending"] - token_counts[token]
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"pending": new_pending},
+ )
+
# Delete the corresponding completed credentials.
self.db_pool.simple_delete_many_txn(
txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 9d28d69ac7..65dde67ae9 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
async def update_profile_in_user_dir(
- self, user_id: str, display_name: str, avatar_url: str
+ self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
) -> None:
"""
Update or add a user's profile in the user directory.
|