diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 1a859352b6..1f6c93b73b 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -37,7 +37,57 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
-class RegistrationWorkerStore(SQLBaseStore):
+class RegistrationDeactivationStore(SQLBaseStore):
+ @cachedInlineCallbacks()
+ def get_user_deactivated_status(self, user_id):
+ """Retrieve the value for the `deactivated` property for the provided user.
+
+ Args:
+ user_id (str): The ID of the user to retrieve the status for.
+
+ Returns:
+ defer.Deferred(bool): The requested value.
+ """
+
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="deactivated",
+ desc="get_user_deactivated_status",
+ )
+
+ # Convert the integer into a boolean.
+ return res == 1
+
+ @defer.inlineCallbacks
+ def set_user_deactivated_status(self, user_id, deactivated):
+ """Set the `deactivated` property for the provided user to the provided value.
+
+ Args:
+ user_id (str): The ID of the user to set the status for.
+ deactivated (bool): The value to set for `deactivated`.
+ """
+
+ yield self.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
+
+ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ self._simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
+ )
+
+
+class RegistrationWorkerStore(RegistrationDeactivationStore):
def __init__(self, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
@@ -673,27 +723,6 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_id_servers_user_bound",
)
- @cachedInlineCallbacks()
- def get_user_deactivated_status(self, user_id):
- """Retrieve the value for the `deactivated` property for the provided user.
-
- Args:
- user_id (str): The ID of the user to retrieve the status for.
-
- Returns:
- defer.Deferred(bool): The requested value.
- """
-
- res = yield self._simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="deactivated",
- desc="get_user_deactivated_status",
- )
-
- # Convert the integer into a boolean.
- return res == 1
-
def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True
):
@@ -787,13 +816,14 @@ class RegistrationWorkerStore(SQLBaseStore):
)
-class RegistrationStore(
- RegistrationWorkerStore, background_updates.BackgroundUpdateStore
+class RegistrationBackgroundUpdateStore(
+ RegistrationDeactivationStore, background_updates.BackgroundUpdateStore
):
def __init__(self, db_conn, hs):
- super(RegistrationStore, self).__init__(db_conn, hs)
+ super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
+ self.config = hs.config
self.register_background_index_update(
"access_tokens_device_index",
@@ -809,8 +839,6 @@ class RegistrationStore(
columns=["creation_ts"],
)
- self._account_validity = hs.config.account_validity
-
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@@ -824,17 +852,6 @@ class RegistrationStore(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
-
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
@defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@@ -897,6 +914,54 @@ class RegistrationStore(
return nb_processed
@defer.inlineCallbacks
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server configured trusted identity servers.
+ """
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
+ )
+
+ yield self._end_background_update("user_threepids_grandfather")
+
+ return 1
+
+
+class RegistrationStore(RegistrationWorkerStore, RegistrationBackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(RegistrationStore, self).__init__(db_conn, hs)
+
+ self._account_validity = hs.config.account_validity
+
+ # Create a background job for culling expired 3PID validity tokens
+ def start_cull():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "cull_expired_threepid_validation_tokens",
+ self.cull_expired_threepid_validation_tokens,
+ )
+
+ hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+
+ @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
@@ -1244,36 +1309,6 @@ class RegistrationStore(
desc="get_users_pending_deactivation",
)
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- yield self.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
- )
-
- yield self._end_background_update("user_threepids_grandfather")
-
- return 1
-
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
@@ -1464,30 +1499,3 @@ class RegistrationStore(
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)
-
- def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self._simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"deactivated": 1 if deactivated else 0},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,)
- )
-
- @defer.inlineCallbacks
- def set_user_deactivated_status(self, user_id, deactivated):
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id (str): The ID of the user to set the status for.
- deactivated (bool): The value to set for `deactivated`.
- """
-
- yield self.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id,
- deactivated,
- )
|