diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 109 |
1 files changed, 48 insertions, 61 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index d36917e4d6..983ce13291 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -116,8 +116,9 @@ class RegistrationWorkerStore(SQLBaseStore): defer.returnValue(res) @defer.inlineCallbacks - def set_account_validity_for_user(self, user_id, expiration_ts, email_sent, - renewal_token=None): + def set_account_validity_for_user( + self, user_id, expiration_ts, email_sent, renewal_token=None + ): """Updates the account validity properties of the given account, with the given values. @@ -131,6 +132,7 @@ class RegistrationWorkerStore(SQLBaseStore): renewal_token (str): Renewal token the user can use to extend the validity of their account. Defaults to no token. """ + def set_account_validity_for_user_txn(txn): self._simple_update_txn( txn=txn, @@ -143,12 +145,11 @@ class RegistrationWorkerStore(SQLBaseStore): }, ) self._invalidate_cache_and_stream( - txn, self.get_expiration_ts_for_user, (user_id,), + txn, self.get_expiration_ts_for_user, (user_id,) ) yield self.runInteraction( - "set_account_validity_for_user", - set_account_validity_for_user_txn, + "set_account_validity_for_user", set_account_validity_for_user_txn ) @defer.inlineCallbacks @@ -217,6 +218,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] """ + def select_users_txn(txn, now_ms, renew_at): sql = ( "SELECT user_id, expiration_ts_ms FROM account_validity" @@ -229,7 +231,8 @@ class RegistrationWorkerStore(SQLBaseStore): res = yield self.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), self.config.account_validity.renew_at, + self.clock.time_msec(), + self.config.account_validity.renew_at, ) defer.returnValue(res) @@ -369,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore): WHERE creation_ts > ? ) AS t GROUP BY user_type """ - results = {'native': 0, 'guest': 0, 'bridged': 0} + results = {"native": 0, "guest": 0, "bridged": 0} txn.execute(sql, (yesterday,)) for row in txn: results[row[0]] = row[1] @@ -435,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore): {"medium": medium, "address": address}, ["guest_access_token"], True, - 'get_3pid_guest_access_token', + "get_3pid_guest_access_token", ) if ret: defer.returnValue(ret["guest_access_token"]) @@ -472,11 +475,11 @@ class RegistrationWorkerStore(SQLBaseStore): txn, "user_threepids", {"medium": medium, "address": address}, - ['user_id'], + ["user_id"], True, ) if ret: - return ret['user_id'] + return ret["user_id"] return None @defer.inlineCallbacks @@ -492,8 +495,8 @@ class RegistrationWorkerStore(SQLBaseStore): ret = yield self._simple_select_list( "user_threepids", {"user_id": user_id}, - ['medium', 'address', 'validated_at', 'added_at'], - 'user_get_threepids', + ["medium", "address", "validated_at", "added_at"], + "user_get_threepids", ) defer.returnValue(ret) @@ -572,11 +575,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ return self._simple_select_onecol( table="user_threepid_id_server", - keyvalues={ - "user_id": user_id, - "medium": medium, - "address": address, - }, + keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", desc="get_id_servers_user_bound", ) @@ -612,16 +611,16 @@ class RegistrationStore( self.register_noop_background_update("refresh_tokens_device_index") self.register_background_update_handler( - "user_threepids_grandfather", self._bg_user_threepids_grandfather, + "user_threepids_grandfather", self._bg_user_threepids_grandfather ) self.register_background_update_handler( - "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag, + "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag ) # Create a background job for culling expired 3PID validity tokens hs.get_clock().looping_call( - self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS, + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @defer.inlineCallbacks @@ -662,7 +661,7 @@ class RegistrationStore( for user in rows: if not user["count_tokens"] and not user["count_threepids"]: - self.set_user_deactivated_status_txn(txn, user["user_id"], True) + self.set_user_deactivated_status_txn(txn, user["name"], True) rows_processed_nb += 1 logger.info("Marked %d rows as deactivated", rows_processed_nb) @@ -677,8 +676,7 @@ class RegistrationStore( return False end = yield self.runInteraction( - "users_set_deactivated_flag", - _backgroud_update_set_deactivated_flag_txn, + "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn ) if end: @@ -851,7 +849,7 @@ class RegistrationStore( def user_set_password_hash_txn(txn): self._simple_update_one_txn( - txn, 'users', {'name': user_id}, {'password_hash': password_hash} + txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -872,9 +870,9 @@ class RegistrationStore( def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_version': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_version": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -896,9 +894,9 @@ class RegistrationStore( def f(txn): self._simple_update_one_txn( txn, - table='users', - keyvalues={'name': user_id}, - updatevalues={'consent_server_notice_sent': consent_version}, + table="users", + keyvalues={"name": user_id}, + updatevalues={"consent_server_notice_sent": consent_version}, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1068,7 +1066,7 @@ class RegistrationStore( if id_servers: yield self.runInteraction( - "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn, + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) yield self._end_background_update("user_threepids_grandfather") @@ -1076,12 +1074,7 @@ class RegistrationStore( defer.returnValue(1) def get_threepid_validation_session( - self, - medium, - client_secret, - address=None, - sid=None, - validated=True, + self, medium, client_secret, address=None, sid=None, validated=True ): """Gets a session_id and last_send_attempt (if available) for a client_secret/medium/(address|session_id) combo @@ -1101,23 +1094,22 @@ class RegistrationStore( latest session_id and send_attempt count for this 3PID. Otherwise None if there hasn't been a previous attempt """ - keyvalues = { - "medium": medium, - "client_secret": client_secret, - } + keyvalues = {"medium": medium, "client_secret": client_secret} if address: keyvalues["address"] = address if sid: keyvalues["session_id"] = sid - assert(address or sid) + assert address or sid def get_threepid_validation_session_txn(txn): sql = """ SELECT address, session_id, medium, client_secret, last_send_attempt, validated_at FROM threepid_validation_session WHERE %s - """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),) + """ % ( + " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), + ) if validated is not None: sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") @@ -1132,17 +1124,10 @@ class RegistrationStore( return rows[0] return self.runInteraction( - "get_threepid_validation_session", - get_threepid_validation_session_txn, + "get_threepid_validation_session", get_threepid_validation_session_txn ) - def validate_threepid_session( - self, - session_id, - client_secret, - token, - current_ts, - ): + def validate_threepid_session(self, session_id, client_secret, token, current_ts): """Attempt to validate a threepid session using a token Args: @@ -1174,7 +1159,7 @@ class RegistrationStore( if retrieved_client_secret != client_secret: raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id", + 400, "This client_secret does not match the provided session_id" ) row = self._simple_select_one_txn( @@ -1187,7 +1172,7 @@ class RegistrationStore( if not row: raise ThreepidValidationError( - 400, "Validation token not found or has expired", + 400, "Validation token not found or has expired" ) expires = row["expires"] next_link = row["next_link"] @@ -1198,7 +1183,7 @@ class RegistrationStore( if expires <= current_ts: raise ThreepidValidationError( - 400, "This token has expired. Please request a new one", + 400, "This token has expired. Please request a new one" ) # Looks good. Validate the session @@ -1213,8 +1198,7 @@ class RegistrationStore( # Return next_link if it exists return self.runInteraction( - "validate_threepid_session_txn", - validate_threepid_session_txn, + "validate_threepid_session_txn", validate_threepid_session_txn ) def upsert_threepid_validation_session( @@ -1281,6 +1265,7 @@ class RegistrationStore( token_expires (int): The timestamp for which after the token will no longer be valid """ + def start_or_continue_validation_session_txn(txn): # Create or update a validation session self._simple_upsert_txn( @@ -1314,6 +1299,7 @@ class RegistrationStore( def cull_expired_threepid_validation_tokens(self): """Remove threepid validation tokens with expiry dates that have passed""" + def cull_expired_threepid_validation_tokens_txn(txn, ts): sql = """ DELETE FROM threepid_validation_token WHERE @@ -1335,6 +1321,7 @@ class RegistrationStore( Args: session_id (str): The ID of the session to delete """ + def delete_threepid_session_txn(txn): self._simple_delete_txn( txn, @@ -1348,8 +1335,7 @@ class RegistrationStore( ) return self.runInteraction( - "delete_threepid_session", - delete_threepid_session_txn, + "delete_threepid_session", delete_threepid_session_txn ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): @@ -1360,7 +1346,7 @@ class RegistrationStore( updatevalues={"deactivated": 1 if deactivated else 0}, ) self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,), + txn, self.get_user_deactivated_status, (user_id,) ) @defer.inlineCallbacks @@ -1375,7 +1361,8 @@ class RegistrationStore( yield self.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, - user_id, deactivated, + user_id, + deactivated, ) @cachedInlineCallbacks() |