diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 80d9606783..d856b76952 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -21,9 +21,11 @@ from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached
@@ -33,7 +35,7 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
-class RegistrationWorkerStore(SQLBaseStore):
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -1020,13 +1022,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return 1
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
+ """Set the `deactivated` property for the provided user to the provided value.
+
+ Args:
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
+ """
+
+ await self.db_pool.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
+
+ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
+ )
+ txn.call_after(self.is_guest.invalidate, (user_id,))
-class RegistrationStore(RegistrationBackgroundUpdateStore):
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
+
+
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+
async def add_access_token_to_user(
self,
user_id: str,
@@ -1378,18 +1423,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1551,35 +1584,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def set_user_deactivated_status(
- self, user_id: str, deactivated: bool
- ) -> None:
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id: The ID of the user to set the status for.
- deactivated: The value to set for `deactivated`.
- """
-
- await self.db_pool.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id,
- deactivated,
- )
-
- def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"deactivated": 1 if deactivated else 0},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,)
- )
- txn.call_after(self.is_guest.invalidate, (user_id,))
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
|