summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py587
1 files changed, 566 insertions, 21 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 643f7a3808..d36917e4d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,25 +15,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import re
 
+from six import iterkeys
 from six.moves import range
 
 from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError
+from synapse.api.errors import Codes, StoreError, ThreepidValidationError
 from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
+THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
+
+logger = logging.getLogger(__name__)
+
 
 class RegistrationWorkerStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(RegistrationWorkerStore, self).__init__(db_conn, hs)
 
         self.config = hs.config
+        self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
@@ -87,26 +96,176 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cachedInlineCallbacks()
-    def get_expiration_ts_for_user(self, user):
+    def get_expiration_ts_for_user(self, user_id):
         """Get the expiration timestamp for the account bearing a given user ID.
 
         Args:
-            user (str): The ID of the user.
+            user_id (str): The ID of the user.
         Returns:
             defer.Deferred: None, if the account has no expiration timestamp,
-            otherwise int representation of the timestamp (as a number of
-            milliseconds since epoch).
+                otherwise int representation of the timestamp (as a number of
+                milliseconds since epoch).
         """
         res = yield self._simple_select_one_onecol(
             table="account_validity",
-            keyvalues={"user_id": user.to_string()},
+            keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
             allow_none=True,
-            desc="get_expiration_date_for_user",
+            desc="get_expiration_ts_for_user",
         )
         defer.returnValue(res)
 
     @defer.inlineCallbacks
+    def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
+                                      renewal_token=None):
+        """Updates the account validity properties of the given account, with the
+        given values.
+
+        Args:
+            user_id (str): ID of the account to update properties for.
+            expiration_ts (int): New expiration date, as a timestamp in milliseconds
+                since epoch.
+            email_sent (bool): True means a renewal email has been sent for this
+                account and there's no need to send another one for the current validity
+                period.
+            renewal_token (str): Renewal token the user can use to extend the validity
+                of their account. Defaults to no token.
+        """
+        def set_account_validity_for_user_txn(txn):
+            self._simple_update_txn(
+                txn=txn,
+                table="account_validity",
+                keyvalues={"user_id": user_id},
+                updatevalues={
+                    "expiration_ts_ms": expiration_ts,
+                    "email_sent": email_sent,
+                    "renewal_token": renewal_token,
+                },
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_expiration_ts_for_user, (user_id,),
+            )
+
+        yield self.runInteraction(
+            "set_account_validity_for_user",
+            set_account_validity_for_user_txn,
+        )
+
+    @defer.inlineCallbacks
+    def set_renewal_token_for_user(self, user_id, renewal_token):
+        """Defines a renewal token for a given user.
+
+        Args:
+            user_id (str): ID of the user to set the renewal token for.
+            renewal_token (str): Random unique string that will be used to renew the
+                user's account.
+
+        Raises:
+            StoreError: The provided token is already set for another user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"renewal_token": renewal_token},
+            desc="set_renewal_token_for_user",
+        )
+
+    @defer.inlineCallbacks
+    def get_user_from_renewal_token(self, renewal_token):
+        """Get a user ID from a renewal token.
+
+        Args:
+            renewal_token (str): The renewal token to perform the lookup with.
+
+        Returns:
+            defer.Deferred[str]: The ID of the user to which the token belongs.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"renewal_token": renewal_token},
+            retcol="user_id",
+            desc="get_user_from_renewal_token",
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_renewal_token_for_user(self, user_id):
+        """Get the renewal token associated with a given user ID.
+
+        Args:
+            user_id (str): The user ID to lookup a token for.
+
+        Returns:
+            defer.Deferred[str]: The renewal token associated with this user ID.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            retcol="renewal_token",
+            desc="get_renewal_token_for_user",
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_users_expiring_soon(self):
+        """Selects users whose account will expire in the [now, now + renew_at] time
+        window (see configuration for account_validity for information on what renew_at
+        refers to).
+
+        Returns:
+            Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+        """
+        def select_users_txn(txn, now_ms, renew_at):
+            sql = (
+                "SELECT user_id, expiration_ts_ms FROM account_validity"
+                " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+            )
+            values = [False, now_ms, renew_at]
+            txn.execute(sql, values)
+            return self.cursor_to_dict(txn)
+
+        res = yield self.runInteraction(
+            "get_users_expiring_soon",
+            select_users_txn,
+            self.clock.time_msec(), self.config.account_validity.renew_at,
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def set_renewal_mail_status(self, user_id, email_sent):
+        """Sets or unsets the flag that indicates whether a renewal email has been sent
+        to the user (and the user hasn't renewed their account yet).
+
+        Args:
+            user_id (str): ID of the user to set/unset the flag for.
+            email_sent (bool): Flag which indicates whether a renewal email has been sent
+                to this user.
+        """
+        yield self._simple_update_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            updatevalues={"email_sent": email_sent},
+            desc="set_renewal_mail_status",
+        )
+
+    @defer.inlineCallbacks
+    def delete_account_validity_for_user(self, user_id):
+        """Deletes the entry for the given user in the account validity table, removing
+        their expiration date and renewal token.
+
+        Args:
+            user_id (str): ID of the user to remove from the account validity table.
+        """
+        yield self._simple_delete_one(
+            table="account_validity",
+            keyvalues={"user_id": user_id},
+            desc="delete_account_validity_for_user",
+        )
+
+    @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
             table="users",
@@ -283,7 +442,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         defer.returnValue(None)
 
     @defer.inlineCallbacks
-    def get_user_id_by_threepid(self, medium, address):
+    def get_user_id_by_threepid(self, medium, address, require_verified=False):
         """Returns user id from threepid
 
         Args:
@@ -456,6 +615,77 @@ class RegistrationStore(
             "user_threepids_grandfather", self._bg_user_threepids_grandfather,
         )
 
+        self.register_background_update_handler(
+            "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag,
+        )
+
+        # Create a background job for culling expired 3PID validity tokens
+        hs.get_clock().looping_call(
+            self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+        )
+
+    @defer.inlineCallbacks
+    def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+        """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
+        for each of them.
+        """
+
+        last_user = progress.get("user_id", "")
+
+        def _backgroud_update_set_deactivated_flag_txn(txn):
+            txn.execute(
+                """
+                SELECT
+                    users.name,
+                    COUNT(access_tokens.token) AS count_tokens,
+                    COUNT(user_threepids.address) AS count_threepids
+                FROM users
+                    LEFT JOIN access_tokens ON (access_tokens.user_id = users.name)
+                    LEFT JOIN user_threepids ON (user_threepids.user_id = users.name)
+                WHERE (users.password_hash IS NULL OR users.password_hash = '')
+                AND (users.appservice_id IS NULL OR users.appservice_id = '')
+                AND users.is_guest = 0
+                AND users.name > ?
+                GROUP BY users.name
+                ORDER BY users.name ASC
+                LIMIT ?;
+                """,
+                (last_user, batch_size),
+            )
+
+            rows = self.cursor_to_dict(txn)
+
+            if not rows:
+                return True
+
+            rows_processed_nb = 0
+
+            for user in rows:
+                if not user["count_tokens"] and not user["count_threepids"]:
+                    self.set_user_deactivated_status_txn(txn, user["user_id"], True)
+                    rows_processed_nb += 1
+
+            logger.info("Marked %d rows as deactivated", rows_processed_nb)
+
+            self._background_update_progress_txn(
+                txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+            )
+
+            if batch_size > len(rows):
+                return True
+            else:
+                return False
+
+        end = yield self.runInteraction(
+            "users_set_deactivated_flag",
+            _backgroud_update_set_deactivated_flag_txn,
+        )
+
+        if end:
+            yield self._end_background_update("users_set_deactivated_flag")
+
+        defer.returnValue(batch_size)
+
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
@@ -584,20 +814,12 @@ class RegistrationStore(
                     },
                 )
 
-            if self._account_validity.enabled:
-                now_ms = self.clock.time_msec()
-                expiration_ts = now_ms + self._account_validity.period
-                self._simple_insert_txn(
-                    txn,
-                    "account_validity",
-                    values={
-                        "user_id": user_id,
-                        "expiration_ts_ms": expiration_ts,
-                    }
-                )
         except self.database_engine.module.IntegrityError:
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
+        if self._account_validity.enabled:
+            self.set_expiration_date_for_user_txn(txn, user_id)
+
         if token:
             # it's possible for this to get a conflict, but only for a single user
             # since tokens are namespaced based on their user ID
@@ -832,7 +1054,6 @@ class RegistrationStore(
         We do this by grandfathering in existing user threepids assuming that
         they used one of the server configured trusted identity servers.
         """
-
         id_servers = set(self.config.trusted_third_party_id_servers)
 
         def _bg_user_threepids_grandfather_txn(txn):
@@ -853,3 +1074,327 @@ class RegistrationStore(
         yield self._end_background_update("user_threepids_grandfather")
 
         defer.returnValue(1)
+
+    def get_threepid_validation_session(
+        self,
+        medium,
+        client_secret,
+        address=None,
+        sid=None,
+        validated=True,
+    ):
+        """Gets a session_id and last_send_attempt (if available) for a
+        client_secret/medium/(address|session_id) combo
+
+        Args:
+            medium (str|None): The medium of the 3PID
+            address (str|None): The address of the 3PID
+            sid (str|None): The ID of the validation session
+            client_secret (str|None): A unique string provided by the client to
+                help identify this validation attempt
+            validated (bool|None): Whether sessions should be filtered by
+                whether they have been validated already or not. None to
+                perform no filtering
+
+        Returns:
+            deferred {str, int}|None: A dict containing the
+                latest session_id and send_attempt count for this 3PID.
+                Otherwise None if there hasn't been a previous attempt
+        """
+        keyvalues = {
+            "medium": medium,
+            "client_secret": client_secret,
+        }
+        if address:
+            keyvalues["address"] = address
+        if sid:
+            keyvalues["session_id"] = sid
+
+        assert(address or sid)
+
+        def get_threepid_validation_session_txn(txn):
+            sql = """
+                SELECT address, session_id, medium, client_secret,
+                last_send_attempt, validated_at
+                FROM threepid_validation_session WHERE %s
+                """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
+
+            if validated is not None:
+                sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
+
+            sql += " LIMIT 1"
+
+            txn.execute(sql, list(keyvalues.values()))
+            rows = self.cursor_to_dict(txn)
+            if not rows:
+                return None
+
+            return rows[0]
+
+        return self.runInteraction(
+            "get_threepid_validation_session",
+            get_threepid_validation_session_txn,
+        )
+
+    def validate_threepid_session(
+        self,
+        session_id,
+        client_secret,
+        token,
+        current_ts,
+    ):
+        """Attempt to validate a threepid session using a token
+
+        Args:
+            session_id (str): The id of a validation session
+            client_secret (str): A unique string provided by the client to
+                help identify this validation attempt
+            token (str): A validation token
+            current_ts (int): The current unix time in milliseconds. Used for
+                checking token expiry status
+
+        Returns:
+            deferred str|None: A str representing a link to redirect the user
+            to if there is one.
+        """
+        # Insert everything into a transaction in order to run atomically
+        def validate_threepid_session_txn(txn):
+            row = self._simple_select_one_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+                retcols=["client_secret", "validated_at"],
+                allow_none=True,
+            )
+
+            if not row:
+                raise ThreepidValidationError(400, "Unknown session_id")
+            retrieved_client_secret = row["client_secret"]
+            validated_at = row["validated_at"]
+
+            if retrieved_client_secret != client_secret:
+                raise ThreepidValidationError(
+                    400, "This client_secret does not match the provided session_id",
+                )
+
+            row = self._simple_select_one_txn(
+                txn,
+                table="threepid_validation_token",
+                keyvalues={"session_id": session_id, "token": token},
+                retcols=["expires", "next_link"],
+                allow_none=True,
+            )
+
+            if not row:
+                raise ThreepidValidationError(
+                    400, "Validation token not found or has expired",
+                )
+            expires = row["expires"]
+            next_link = row["next_link"]
+
+            # If the session is already validated, no need to revalidate
+            if validated_at:
+                return next_link
+
+            if expires <= current_ts:
+                raise ThreepidValidationError(
+                    400, "This token has expired. Please request a new one",
+                )
+
+            # Looks good. Validate the session
+            self._simple_update_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+                updatevalues={"validated_at": self.clock.time_msec()},
+            )
+
+            return next_link
+
+        # Return next_link if it exists
+        return self.runInteraction(
+            "validate_threepid_session_txn",
+            validate_threepid_session_txn,
+        )
+
+    def upsert_threepid_validation_session(
+        self,
+        medium,
+        address,
+        client_secret,
+        send_attempt,
+        session_id,
+        validated_at=None,
+    ):
+        """Upsert a threepid validation session
+        Args:
+            medium (str): The medium of the 3PID
+            address (str): The address of the 3PID
+            client_secret (str): A unique string provided by the client to
+                help identify this validation attempt
+            send_attempt (int): The latest send_attempt on this session
+            session_id (str): The id of this validation session
+            validated_at (int|None): The unix timestamp in milliseconds of
+                when the session was marked as valid
+        """
+        insertion_values = {
+            "medium": medium,
+            "address": address,
+            "client_secret": client_secret,
+        }
+
+        if validated_at:
+            insertion_values["validated_at"] = validated_at
+
+        return self._simple_upsert(
+            table="threepid_validation_session",
+            keyvalues={"session_id": session_id},
+            values={"last_send_attempt": send_attempt},
+            insertion_values=insertion_values,
+            desc="upsert_threepid_validation_session",
+        )
+
+    def start_or_continue_validation_session(
+        self,
+        medium,
+        address,
+        session_id,
+        client_secret,
+        send_attempt,
+        next_link,
+        token,
+        token_expires,
+    ):
+        """Creates a new threepid validation session if it does not already
+        exist and associates a new validation token with it
+
+        Args:
+            medium (str): The medium of the 3PID
+            address (str): The address of the 3PID
+            session_id (str): The id of this validation session
+            client_secret (str): A unique string provided by the client to
+                help identify this validation attempt
+            send_attempt (int): The latest send_attempt on this session
+            next_link (str|None): The link to redirect the user to upon
+                successful validation
+            token (str): The validation token
+            token_expires (int): The timestamp for which after the token
+                will no longer be valid
+        """
+        def start_or_continue_validation_session_txn(txn):
+            # Create or update a validation session
+            self._simple_upsert_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+                values={"last_send_attempt": send_attempt},
+                insertion_values={
+                    "medium": medium,
+                    "address": address,
+                    "client_secret": client_secret,
+                },
+            )
+
+            # Create a new validation token with this session ID
+            self._simple_insert_txn(
+                txn,
+                table="threepid_validation_token",
+                values={
+                    "session_id": session_id,
+                    "token": token,
+                    "next_link": next_link,
+                    "expires": token_expires,
+                },
+            )
+
+        return self.runInteraction(
+            "start_or_continue_validation_session",
+            start_or_continue_validation_session_txn,
+        )
+
+    def cull_expired_threepid_validation_tokens(self):
+        """Remove threepid validation tokens with expiry dates that have passed"""
+        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+            sql = """
+            DELETE FROM threepid_validation_token WHERE
+            expires < ?
+            """
+            return txn.execute(sql, (ts,))
+
+        return self.runInteraction(
+            "cull_expired_threepid_validation_tokens",
+            cull_expired_threepid_validation_tokens_txn,
+            self.clock.time_msec(),
+        )
+
+    def delete_threepid_session(self, session_id):
+        """Removes a threepid validation session from the database. This can
+        be done after validation has been performed and whatever action was
+        waiting on it has been carried out
+
+        Args:
+            session_id (str): The ID of the session to delete
+        """
+        def delete_threepid_session_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_token",
+                keyvalues={"session_id": session_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="threepid_validation_session",
+                keyvalues={"session_id": session_id},
+            )
+
+        return self.runInteraction(
+            "delete_threepid_session",
+            delete_threepid_session_txn,
+        )
+
+    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+        self._simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,),
+        )
+
+    @defer.inlineCallbacks
+    def set_user_deactivated_status(self, user_id, deactivated):
+        """Set the `deactivated` property for the provided user to the provided value.
+
+        Args:
+            user_id (str): The ID of the user to set the status for.
+            deactivated (bool): The value to set for `deactivated`.
+        """
+
+        yield self.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id, deactivated,
+        )
+
+    @cachedInlineCallbacks()
+    def get_user_deactivated_status(self, user_id):
+        """Retrieve the value for the `deactivated` property for the provided user.
+
+        Args:
+            user_id (str): The ID of the user to retrieve the status for.
+
+        Returns:
+            defer.Deferred(bool): The requested value.
+        """
+
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="deactivated",
+            desc="get_user_deactivated_status",
+        )
+
+        # Convert the integer into a boolean.
+        defer.returnValue(res == 1)