summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py227
1 files changed, 172 insertions, 55 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 1dd1182e82..983ce13291 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import re
 
 from six import iterkeys
@@ -31,6 +32,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 
+logger = logging.getLogger(__name__)
+
 
 class RegistrationWorkerStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
@@ -113,8 +116,9 @@ class RegistrationWorkerStore(SQLBaseStore):
         defer.returnValue(res)
 
     @defer.inlineCallbacks
-    def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
-                                      renewal_token=None):
+    def set_account_validity_for_user(
+        self, user_id, expiration_ts, email_sent, renewal_token=None
+    ):
         """Updates the account validity properties of the given account, with the
         given values.
 
@@ -128,6 +132,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             renewal_token (str): Renewal token the user can use to extend the validity
                 of their account. Defaults to no token.
         """
+
         def set_account_validity_for_user_txn(txn):
             self._simple_update_txn(
                 txn=txn,
@@ -140,12 +145,11 @@ class RegistrationWorkerStore(SQLBaseStore):
                 },
             )
             self._invalidate_cache_and_stream(
-                txn, self.get_expiration_ts_for_user, (user_id,),
+                txn, self.get_expiration_ts_for_user, (user_id,)
             )
 
         yield self.runInteraction(
-            "set_account_validity_for_user",
-            set_account_validity_for_user_txn,
+            "set_account_validity_for_user", set_account_validity_for_user_txn
         )
 
     @defer.inlineCallbacks
@@ -214,6 +218,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
         """
+
         def select_users_txn(txn, now_ms, renew_at):
             sql = (
                 "SELECT user_id, expiration_ts_ms FROM account_validity"
@@ -226,7 +231,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         res = yield self.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
-            self.clock.time_msec(), self.config.account_validity.renew_at,
+            self.clock.time_msec(),
+            self.config.account_validity.renew_at,
         )
 
         defer.returnValue(res)
@@ -249,6 +255,20 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
+    def delete_account_validity_for_user(self, user_id):
+        """Deletes the entry for the given user in the account validity table, removing
+        their expiration date and renewal token.
+
+        Args:
+            user_id (str): ID of the user to remove from the account validity table.
+        """
+        yield self._simple_delete_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            desc="delete_account_validity_for_user",
+        )
+
+    @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
             table="users",
@@ -352,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                     WHERE creation_ts > ?
                 ) AS t GROUP BY user_type
             """
-            results = {'native': 0, 'guest': 0, 'bridged': 0}
+            results = {"native": 0, "guest": 0, "bridged": 0}
             txn.execute(sql, (yesterday,))
             for row in txn:
                 results[row[0]] = row[1]
@@ -418,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             {"medium": medium, "address": address},
             ["guest_access_token"],
             True,
-            'get_3pid_guest_access_token',
+            "get_3pid_guest_access_token",
         )
         if ret:
             defer.returnValue(ret["guest_access_token"])
@@ -455,11 +475,11 @@ class RegistrationWorkerStore(SQLBaseStore):
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
-            ['user_id'],
+            ["user_id"],
             True,
         )
         if ret:
-            return ret['user_id']
+            return ret["user_id"]
         return None
 
     @defer.inlineCallbacks
@@ -475,8 +495,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         ret = yield self._simple_select_list(
             "user_threepids",
             {"user_id": user_id},
-            ['medium', 'address', 'validated_at', 'added_at'],
-            'user_get_threepids',
+            ["medium", "address", "validated_at", "added_at"],
+            "user_get_threepids",
         )
         defer.returnValue(ret)
 
@@ -555,11 +575,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
         return self._simple_select_onecol(
             table="user_threepid_id_server",
-            keyvalues={
-                "user_id": user_id,
-                "medium": medium,
-                "address": address,
-            },
+            keyvalues={"user_id": user_id, "medium": medium, "address": address},
             retcol="id_server",
             desc="get_id_servers_user_bound",
         )
@@ -595,15 +611,80 @@ class RegistrationStore(
         self.register_noop_background_update("refresh_tokens_device_index")
 
         self.register_background_update_handler(
-            "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+            "user_threepids_grandfather", self._bg_user_threepids_grandfather
+        )
+
+        self.register_background_update_handler(
+            "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag
         )
 
         # Create a background job for culling expired 3PID validity tokens
         hs.get_clock().looping_call(
-            self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+            self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
         )
 
     @defer.inlineCallbacks
+    def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+        """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
+        for each of them.
+        """
+
+        last_user = progress.get("user_id", "")
+
+        def _backgroud_update_set_deactivated_flag_txn(txn):
+            txn.execute(
+                """
+                SELECT
+                    users.name,
+                    COUNT(access_tokens.token) AS count_tokens,
+                    COUNT(user_threepids.address) AS count_threepids
+                FROM users
+                    LEFT JOIN access_tokens ON (access_tokens.user_id = users.name)
+                    LEFT JOIN user_threepids ON (user_threepids.user_id = users.name)
+                WHERE (users.password_hash IS NULL OR users.password_hash = '')
+                AND (users.appservice_id IS NULL OR users.appservice_id = '')
+                AND users.is_guest = 0
+                AND users.name > ?
+                GROUP BY users.name
+                ORDER BY users.name ASC
+                LIMIT ?;
+                """,
+                (last_user, batch_size),
+            )
+
+            rows = self.cursor_to_dict(txn)
+
+            if not rows:
+                return True
+
+            rows_processed_nb = 0
+
+            for user in rows:
+                if not user["count_tokens"] and not user["count_threepids"]:
+                    self.set_user_deactivated_status_txn(txn, user["name"], True)
+                    rows_processed_nb += 1
+
+            logger.info("Marked %d rows as deactivated", rows_processed_nb)
+
+            self._background_update_progress_txn(
+                txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+            )
+
+            if batch_size > len(rows):
+                return True
+            else:
+                return False
+
+        end = yield self.runInteraction(
+            "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn
+        )
+
+        if end:
+            yield self._end_background_update("users_set_deactivated_flag")
+
+        defer.returnValue(batch_size)
+
+    @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
 
@@ -768,7 +849,7 @@ class RegistrationStore(
 
         def user_set_password_hash_txn(txn):
             self._simple_update_one_txn(
-                txn, 'users', {'name': user_id}, {'password_hash': password_hash}
+                txn, "users", {"name": user_id}, {"password_hash": password_hash}
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
@@ -789,9 +870,9 @@ class RegistrationStore(
         def f(txn):
             self._simple_update_one_txn(
                 txn,
-                table='users',
-                keyvalues={'name': user_id},
-                updatevalues={'consent_version': consent_version},
+                table="users",
+                keyvalues={"name": user_id},
+                updatevalues={"consent_version": consent_version},
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
@@ -813,9 +894,9 @@ class RegistrationStore(
         def f(txn):
             self._simple_update_one_txn(
                 txn,
-                table='users',
-                keyvalues={'name': user_id},
-                updatevalues={'consent_server_notice_sent': consent_version},
+                table="users",
+                keyvalues={"name": user_id},
+                updatevalues={"consent_server_notice_sent": consent_version},
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
@@ -985,7 +1066,7 @@ class RegistrationStore(
 
         if id_servers:
             yield self.runInteraction(
-                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
             )
 
         yield self._end_background_update("user_threepids_grandfather")
@@ -993,12 +1074,7 @@ class RegistrationStore(
         defer.returnValue(1)
 
     def get_threepid_validation_session(
-        self,
-        medium,
-        client_secret,
-        address=None,
-        sid=None,
-        validated=True,
+        self, medium, client_secret, address=None, sid=None, validated=True
     ):
         """Gets a session_id and last_send_attempt (if available) for a
         client_secret/medium/(address|session_id) combo
@@ -1018,23 +1094,22 @@ class RegistrationStore(
                 latest session_id and send_attempt count for this 3PID.
                 Otherwise None if there hasn't been a previous attempt
         """
-        keyvalues = {
-            "medium": medium,
-            "client_secret": client_secret,
-        }
+        keyvalues = {"medium": medium, "client_secret": client_secret}
         if address:
             keyvalues["address"] = address
         if sid:
             keyvalues["session_id"] = sid
 
-        assert(address or sid)
+        assert address or sid
 
         def get_threepid_validation_session_txn(txn):
             sql = """
                 SELECT address, session_id, medium, client_secret,
                 last_send_attempt, validated_at
                 FROM threepid_validation_session WHERE %s
-                """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
+                """ % (
+                " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+            )
 
             if validated is not None:
                 sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
@@ -1049,17 +1124,10 @@ class RegistrationStore(
             return rows[0]
 
         return self.runInteraction(
-            "get_threepid_validation_session",
-            get_threepid_validation_session_txn,
+            "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
-    def validate_threepid_session(
-        self,
-        session_id,
-        client_secret,
-        token,
-        current_ts,
-    ):
+    def validate_threepid_session(self, session_id, client_secret, token, current_ts):
         """Attempt to validate a threepid session using a token
 
         Args:
@@ -1091,7 +1159,7 @@ class RegistrationStore(
 
             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(
-                    400, "This client_secret does not match the provided session_id",
+                    400, "This client_secret does not match the provided session_id"
                 )
 
             row = self._simple_select_one_txn(
@@ -1104,7 +1172,7 @@ class RegistrationStore(
 
             if not row:
                 raise ThreepidValidationError(
-                    400, "Validation token not found or has expired",
+                    400, "Validation token not found or has expired"
                 )
             expires = row["expires"]
             next_link = row["next_link"]
@@ -1115,7 +1183,7 @@ class RegistrationStore(
 
             if expires <= current_ts:
                 raise ThreepidValidationError(
-                    400, "This token has expired. Please request a new one",
+                    400, "This token has expired. Please request a new one"
                 )
 
             # Looks good. Validate the session
@@ -1130,8 +1198,7 @@ class RegistrationStore(
 
         # Return next_link if it exists
         return self.runInteraction(
-            "validate_threepid_session_txn",
-            validate_threepid_session_txn,
+            "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
     def upsert_threepid_validation_session(
@@ -1198,6 +1265,7 @@ class RegistrationStore(
             token_expires (int): The timestamp for which after the token
                 will no longer be valid
         """
+
         def start_or_continue_validation_session_txn(txn):
             # Create or update a validation session
             self._simple_upsert_txn(
@@ -1231,6 +1299,7 @@ class RegistrationStore(
 
     def cull_expired_threepid_validation_tokens(self):
         """Remove threepid validation tokens with expiry dates that have passed"""
+
         def cull_expired_threepid_validation_tokens_txn(txn, ts):
             sql = """
             DELETE FROM threepid_validation_token WHERE
@@ -1252,6 +1321,7 @@ class RegistrationStore(
         Args:
             session_id (str): The ID of the session to delete
         """
+
         def delete_threepid_session_txn(txn):
             self._simple_delete_txn(
                 txn,
@@ -1265,6 +1335,53 @@ class RegistrationStore(
             )
 
         return self.runInteraction(
-            "delete_threepid_session",
-            delete_threepid_session_txn,
+            "delete_threepid_session", delete_threepid_session_txn
+        )
+
+    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+        self._simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,)
         )
+
+    @defer.inlineCallbacks
+    def set_user_deactivated_status(self, user_id, deactivated):
+        """Set the `deactivated` property for the provided user to the provided value.
+
+        Args:
+            user_id (str): The ID of the user to set the status for.
+            deactivated (bool): The value to set for `deactivated`.
+        """
+
+        yield self.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id,
+            deactivated,
+        )
+
+    @cachedInlineCallbacks()
+    def get_user_deactivated_status(self, user_id):
+        """Retrieve the value for the `deactivated` property for the provided user.
+
+        Args:
+            user_id (str): The ID of the user to retrieve the status for.
+
+        Returns:
+            defer.Deferred(bool): The requested value.
+        """
+
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="deactivated",
+            desc="get_user_deactivated_status",
+        )
+
+        # Convert the integer into a boolean.
+        defer.returnValue(res == 1)