diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 643f7a3808..d36917e4d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,25 +15,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import re
+from six import iterkeys
from six.moves import range
from twisted.internet import defer
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError
+from synapse.api.errors import Codes, StoreError, ThreepidValidationError
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
+
+logger = logging.getLogger(__name__)
+
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
+ self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
@@ -87,26 +96,176 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user):
+ def get_expiration_ts_for_user(self, user_id):
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user (str): The ID of the user.
+ user_id (str): The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ otherwise int representation of the timestamp (as a number of
+ milliseconds since epoch).
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
- keyvalues={"user_id": user.to_string()},
+ keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
- desc="get_expiration_date_for_user",
+ desc="get_expiration_ts_for_user",
)
defer.returnValue(res)
@defer.inlineCallbacks
+ def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
+ renewal_token=None):
+ """Updates the account validity properties of the given account, with the
+ given values.
+
+ Args:
+ user_id (str): ID of the account to update properties for.
+ expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ since epoch.
+ email_sent (bool): True means a renewal email has been sent for this
+ account and there's no need to send another one for the current validity
+ period.
+ renewal_token (str): Renewal token the user can use to extend the validity
+ of their account. Defaults to no token.
+ """
+ def set_account_validity_for_user_txn(txn):
+ self._simple_update_txn(
+ txn=txn,
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": email_sent,
+ "renewal_token": renewal_token,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_expiration_ts_for_user, (user_id,),
+ )
+
+ yield self.runInteraction(
+ "set_account_validity_for_user",
+ set_account_validity_for_user_txn,
+ )
+
+ @defer.inlineCallbacks
+ def set_renewal_token_for_user(self, user_id, renewal_token):
+ """Defines a renewal token for a given user.
+
+ Args:
+ user_id (str): ID of the user to set the renewal token for.
+ renewal_token (str): Random unique string that will be used to renew the
+ user's account.
+
+ Raises:
+ StoreError: The provided token is already set for another user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"renewal_token": renewal_token},
+ desc="set_renewal_token_for_user",
+ )
+
+ @defer.inlineCallbacks
+ def get_user_from_renewal_token(self, renewal_token):
+ """Get a user ID from a renewal token.
+
+ Args:
+ renewal_token (str): The renewal token to perform the lookup with.
+
+ Returns:
+ defer.Deferred[str]: The ID of the user to which the token belongs.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcol="user_id",
+ desc="get_user_from_renewal_token",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_renewal_token_for_user(self, user_id):
+ """Get the renewal token associated with a given user ID.
+
+ Args:
+ user_id (str): The user ID to lookup a token for.
+
+ Returns:
+ defer.Deferred[str]: The renewal token associated with this user ID.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="renewal_token",
+ desc="get_renewal_token_for_user",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_expiring_soon(self):
+ """Selects users whose account will expire in the [now, now + renew_at] time
+ window (see configuration for account_validity for information on what renew_at
+ refers to).
+
+ Returns:
+ Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ """
+ def select_users_txn(txn, now_ms, renew_at):
+ sql = (
+ "SELECT user_id, expiration_ts_ms FROM account_validity"
+ " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ )
+ values = [False, now_ms, renew_at]
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
+ res = yield self.runInteraction(
+ "get_users_expiring_soon",
+ select_users_txn,
+ self.clock.time_msec(), self.config.account_validity.renew_at,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_renewal_mail_status(self, user_id, email_sent):
+ """Sets or unsets the flag that indicates whether a renewal email has been sent
+ to the user (and the user hasn't renewed their account yet).
+
+ Args:
+ user_id (str): ID of the user to set/unset the flag for.
+ email_sent (bool): Flag which indicates whether a renewal email has been sent
+ to this user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"email_sent": email_sent},
+ desc="set_renewal_mail_status",
+ )
+
+ @defer.inlineCallbacks
+ def delete_account_validity_for_user(self, user_id):
+ """Deletes the entry for the given user in the account validity table, removing
+ their expiration date and renewal token.
+
+ Args:
+ user_id (str): ID of the user to remove from the account validity table.
+ """
+ yield self._simple_delete_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ desc="delete_account_validity_for_user",
+ )
+
+ @defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
@@ -283,7 +442,7 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.returnValue(None)
@defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address):
+ def get_user_id_by_threepid(self, medium, address, require_verified=False):
"""Returns user id from threepid
Args:
@@ -456,6 +615,77 @@ class RegistrationStore(
"user_threepids_grandfather", self._bg_user_threepids_grandfather,
)
+ self.register_background_update_handler(
+ "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ hs.get_clock().looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+ )
+
+ @defer.inlineCallbacks
+ def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+ """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
+ for each of them.
+ """
+
+ last_user = progress.get("user_id", "")
+
+ def _backgroud_update_set_deactivated_flag_txn(txn):
+ txn.execute(
+ """
+ SELECT
+ users.name,
+ COUNT(access_tokens.token) AS count_tokens,
+ COUNT(user_threepids.address) AS count_threepids
+ FROM users
+ LEFT JOIN access_tokens ON (access_tokens.user_id = users.name)
+ LEFT JOIN user_threepids ON (user_threepids.user_id = users.name)
+ WHERE (users.password_hash IS NULL OR users.password_hash = '')
+ AND (users.appservice_id IS NULL OR users.appservice_id = '')
+ AND users.is_guest = 0
+ AND users.name > ?
+ GROUP BY users.name
+ ORDER BY users.name ASC
+ LIMIT ?;
+ """,
+ (last_user, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ rows_processed_nb = 0
+
+ for user in rows:
+ if not user["count_tokens"] and not user["count_threepids"]:
+ self.set_user_deactivated_status_txn(txn, user["user_id"], True)
+ rows_processed_nb += 1
+
+ logger.info("Marked %d rows as deactivated", rows_processed_nb)
+
+ self._background_update_progress_txn(
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.runInteraction(
+ "users_set_deactivated_flag",
+ _backgroud_update_set_deactivated_flag_txn,
+ )
+
+ if end:
+ yield self._end_background_update("users_set_deactivated_flag")
+
+ defer.returnValue(batch_size)
+
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -584,20 +814,12 @@ class RegistrationStore(
},
)
- if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- }
- )
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+ if self._account_validity.enabled:
+ self.set_expiration_date_for_user_txn(txn, user_id)
+
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
@@ -832,7 +1054,6 @@ class RegistrationStore(
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
-
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
@@ -853,3 +1074,327 @@ class RegistrationStore(
yield self._end_background_update("user_threepids_grandfather")
defer.returnValue(1)
+
+ def get_threepid_validation_session(
+ self,
+ medium,
+ client_secret,
+ address=None,
+ sid=None,
+ validated=True,
+ ):
+ """Gets a session_id and last_send_attempt (if available) for a
+ client_secret/medium/(address|session_id) combo
+
+ Args:
+ medium (str|None): The medium of the 3PID
+ address (str|None): The address of the 3PID
+ sid (str|None): The ID of the validation session
+ client_secret (str|None): A unique string provided by the client to
+ help identify this validation attempt
+ validated (bool|None): Whether sessions should be filtered by
+ whether they have been validated already or not. None to
+ perform no filtering
+
+ Returns:
+ deferred {str, int}|None: A dict containing the
+ latest session_id and send_attempt count for this 3PID.
+ Otherwise None if there hasn't been a previous attempt
+ """
+ keyvalues = {
+ "medium": medium,
+ "client_secret": client_secret,
+ }
+ if address:
+ keyvalues["address"] = address
+ if sid:
+ keyvalues["session_id"] = sid
+
+ assert(address or sid)
+
+ def get_threepid_validation_session_txn(txn):
+ sql = """
+ SELECT address, session_id, medium, client_secret,
+ last_send_attempt, validated_at
+ FROM threepid_validation_session WHERE %s
+ """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
+
+ if validated is not None:
+ sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
+
+ sql += " LIMIT 1"
+
+ txn.execute(sql, list(keyvalues.values()))
+ rows = self.cursor_to_dict(txn)
+ if not rows:
+ return None
+
+ return rows[0]
+
+ return self.runInteraction(
+ "get_threepid_validation_session",
+ get_threepid_validation_session_txn,
+ )
+
+ def validate_threepid_session(
+ self,
+ session_id,
+ client_secret,
+ token,
+ current_ts,
+ ):
+ """Attempt to validate a threepid session using a token
+
+ Args:
+ session_id (str): The id of a validation session
+ client_secret (str): A unique string provided by the client to
+ help identify this validation attempt
+ token (str): A validation token
+ current_ts (int): The current unix time in milliseconds. Used for
+ checking token expiry status
+
+ Returns:
+ deferred str|None: A str representing a link to redirect the user
+ to if there is one.
+ """
+ # Insert everything into a transaction in order to run atomically
+ def validate_threepid_session_txn(txn):
+ row = self._simple_select_one_txn(
+ txn,
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ retcols=["client_secret", "validated_at"],
+ allow_none=True,
+ )
+
+ if not row:
+ raise ThreepidValidationError(400, "Unknown session_id")
+ retrieved_client_secret = row["client_secret"]
+ validated_at = row["validated_at"]
+
+ if retrieved_client_secret != client_secret:
+ raise ThreepidValidationError(
+ 400, "This client_secret does not match the provided session_id",
+ )
+
+ row = self._simple_select_one_txn(
+ txn,
+ table="threepid_validation_token",
+ keyvalues={"session_id": session_id, "token": token},
+ retcols=["expires", "next_link"],
+ allow_none=True,
+ )
+
+ if not row:
+ raise ThreepidValidationError(
+ 400, "Validation token not found or has expired",
+ )
+ expires = row["expires"]
+ next_link = row["next_link"]
+
+ # If the session is already validated, no need to revalidate
+ if validated_at:
+ return next_link
+
+ if expires <= current_ts:
+ raise ThreepidValidationError(
+ 400, "This token has expired. Please request a new one",
+ )
+
+ # Looks good. Validate the session
+ self._simple_update_txn(
+ txn,
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ updatevalues={"validated_at": self.clock.time_msec()},
+ )
+
+ return next_link
+
+ # Return next_link if it exists
+ return self.runInteraction(
+ "validate_threepid_session_txn",
+ validate_threepid_session_txn,
+ )
+
+ def upsert_threepid_validation_session(
+ self,
+ medium,
+ address,
+ client_secret,
+ send_attempt,
+ session_id,
+ validated_at=None,
+ ):
+ """Upsert a threepid validation session
+ Args:
+ medium (str): The medium of the 3PID
+ address (str): The address of the 3PID
+ client_secret (str): A unique string provided by the client to
+ help identify this validation attempt
+ send_attempt (int): The latest send_attempt on this session
+ session_id (str): The id of this validation session
+ validated_at (int|None): The unix timestamp in milliseconds of
+ when the session was marked as valid
+ """
+ insertion_values = {
+ "medium": medium,
+ "address": address,
+ "client_secret": client_secret,
+ }
+
+ if validated_at:
+ insertion_values["validated_at"] = validated_at
+
+ return self._simple_upsert(
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ values={"last_send_attempt": send_attempt},
+ insertion_values=insertion_values,
+ desc="upsert_threepid_validation_session",
+ )
+
+ def start_or_continue_validation_session(
+ self,
+ medium,
+ address,
+ session_id,
+ client_secret,
+ send_attempt,
+ next_link,
+ token,
+ token_expires,
+ ):
+ """Creates a new threepid validation session if it does not already
+ exist and associates a new validation token with it
+
+ Args:
+ medium (str): The medium of the 3PID
+ address (str): The address of the 3PID
+ session_id (str): The id of this validation session
+ client_secret (str): A unique string provided by the client to
+ help identify this validation attempt
+ send_attempt (int): The latest send_attempt on this session
+ next_link (str|None): The link to redirect the user to upon
+ successful validation
+ token (str): The validation token
+ token_expires (int): The timestamp for which after the token
+ will no longer be valid
+ """
+ def start_or_continue_validation_session_txn(txn):
+ # Create or update a validation session
+ self._simple_upsert_txn(
+ txn,
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ values={"last_send_attempt": send_attempt},
+ insertion_values={
+ "medium": medium,
+ "address": address,
+ "client_secret": client_secret,
+ },
+ )
+
+ # Create a new validation token with this session ID
+ self._simple_insert_txn(
+ txn,
+ table="threepid_validation_token",
+ values={
+ "session_id": session_id,
+ "token": token,
+ "next_link": next_link,
+ "expires": token_expires,
+ },
+ )
+
+ return self.runInteraction(
+ "start_or_continue_validation_session",
+ start_or_continue_validation_session_txn,
+ )
+
+ def cull_expired_threepid_validation_tokens(self):
+ """Remove threepid validation tokens with expiry dates that have passed"""
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ return txn.execute(sql, (ts,))
+
+ return self.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self.clock.time_msec(),
+ )
+
+ def delete_threepid_session(self, session_id):
+ """Removes a threepid validation session from the database. This can
+ be done after validation has been performed and whatever action was
+ waiting on it has been carried out
+
+ Args:
+ session_id (str): The ID of the session to delete
+ """
+ def delete_threepid_session_txn(txn):
+ self._simple_delete_txn(
+ txn,
+ table="threepid_validation_token",
+ keyvalues={"session_id": session_id},
+ )
+ self._simple_delete_txn(
+ txn,
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ )
+
+ return self.runInteraction(
+ "delete_threepid_session",
+ delete_threepid_session_txn,
+ )
+
+ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ self._simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,),
+ )
+
+ @defer.inlineCallbacks
+ def set_user_deactivated_status(self, user_id, deactivated):
+ """Set the `deactivated` property for the provided user to the provided value.
+
+ Args:
+ user_id (str): The ID of the user to set the status for.
+ deactivated (bool): The value to set for `deactivated`.
+ """
+
+ yield self.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id, deactivated,
+ )
+
+ @cachedInlineCallbacks()
+ def get_user_deactivated_status(self, user_id):
+ """Retrieve the value for the `deactivated` property for the provided user.
+
+ Args:
+ user_id (str): The ID of the user to retrieve the status for.
+
+ Returns:
+ defer.Deferred(bool): The requested value.
+ """
+
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="deactivated",
+ desc="get_user_deactivated_status",
+ )
+
+ # Convert the integer into a boolean.
+ defer.returnValue(res == 1)
|