diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index fed4ea3610..9b0a99cb49 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -118,7 +118,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
defer.returnValue(count)
def get_new_device_msgs_for_remote(
- self, destination, last_stream_id, current_stream_id, limit=100
+ self, destination, last_stream_id, current_stream_id, limit
):
"""
Args:
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 643f7a3808..03a06a83d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
+ self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
@@ -87,26 +88,162 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user):
+ def get_expiration_ts_for_user(self, user_id):
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user (str): The ID of the user.
+ user_id (str): The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ otherwise int representation of the timestamp (as a number of
+ milliseconds since epoch).
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
- keyvalues={"user_id": user.to_string()},
+ keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
- desc="get_expiration_date_for_user",
+ desc="get_expiration_ts_for_user",
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
+ renewal_token=None):
+ """Updates the account validity properties of the given account, with the
+ given values.
+
+ Args:
+ user_id (str): ID of the account to update properties for.
+ expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ since epoch.
+ email_sent (bool): True means a renewal email has been sent for this
+ account and there's no need to send another one for the current validity
+ period.
+ renewal_token (str): Renewal token the user can use to extend the validity
+ of their account. Defaults to no token.
+ """
+ def set_account_validity_for_user_txn(txn):
+ self._simple_update_txn(
+ txn=txn,
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": email_sent,
+ "renewal_token": renewal_token,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_expiration_ts_for_user, (user_id,),
+ )
+
+ yield self.runInteraction(
+ "set_account_validity_for_user",
+ set_account_validity_for_user_txn,
+ )
+
+ @defer.inlineCallbacks
+ def set_renewal_token_for_user(self, user_id, renewal_token):
+ """Defines a renewal token for a given user.
+
+ Args:
+ user_id (str): ID of the user to set the renewal token for.
+ renewal_token (str): Random unique string that will be used to renew the
+ user's account.
+
+ Raises:
+ StoreError: The provided token is already set for another user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"renewal_token": renewal_token},
+ desc="set_renewal_token_for_user",
)
+
+ @defer.inlineCallbacks
+ def get_user_from_renewal_token(self, renewal_token):
+ """Get a user ID from a renewal token.
+
+ Args:
+ renewal_token (str): The renewal token to perform the lookup with.
+
+ Returns:
+ defer.Deferred[str]: The ID of the user to which the token belongs.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcol="user_id",
+ desc="get_user_from_renewal_token",
+ )
+
defer.returnValue(res)
@defer.inlineCallbacks
+ def get_renewal_token_for_user(self, user_id):
+ """Get the renewal token associated with a given user ID.
+
+ Args:
+ user_id (str): The user ID to lookup a token for.
+
+ Returns:
+ defer.Deferred[str]: The renewal token associated with this user ID.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="renewal_token",
+ desc="get_renewal_token_for_user",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_expiring_soon(self):
+ """Selects users whose account will expire in the [now, now + renew_at] time
+ window (see configuration for account_validity for information on what renew_at
+ refers to).
+
+ Returns:
+ Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ """
+ def select_users_txn(txn, now_ms, renew_at):
+ sql = (
+ "SELECT user_id, expiration_ts_ms FROM account_validity"
+ " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ )
+ values = [False, now_ms, renew_at]
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
+ res = yield self.runInteraction(
+ "get_users_expiring_soon",
+ select_users_txn,
+ self.clock.time_msec(), self.config.account_validity.renew_at,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_renewal_mail_status(self, user_id, email_sent):
+ """Sets or unsets the flag that indicates whether a renewal email has been sent
+ to the user (and the user hasn't renewed their account yet).
+
+ Args:
+ user_id (str): ID of the user to set/unset the flag for.
+ email_sent (bool): Flag which indicates whether a renewal email has been sent
+ to this user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"email_sent": email_sent},
+ desc="set_renewal_mail_status",
+ )
+
+ @defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
@@ -584,20 +721,22 @@ class RegistrationStore(
},
)
- if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- }
- )
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+ if self._account_validity.enabled:
+ now_ms = self.clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ }
+ )
+
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity.sql
index 57249262d7..2357626000 100644
--- a/synapse/storage/schema/delta/54/account_validity.sql
+++ b/synapse/storage/schema/delta/54/account_validity.sql
@@ -13,8 +13,15 @@
* limitations under the License.
*/
+DROP TABLE IF EXISTS account_validity;
+
-- Track what users are in public rooms.
CREATE TABLE IF NOT EXISTS account_validity (
user_id TEXT PRIMARY KEY,
- expiration_ts_ms BIGINT NOT NULL
+ expiration_ts_ms BIGINT NOT NULL,
+ email_sent BOOLEAN NOT NULL,
+ renewal_token TEXT
);
+
+CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms)
+CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token)
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 56e42f583d..31a0279b18 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -22,6 +22,24 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id):
+ """Fetch a list of room state changes since the given stream id
+
+ Each entry in the result contains the following fields:
+ - stream_id (int)
+ - room_id (str)
+ - type (str): event type
+ - state_key (str):
+ - event_id (str|None): new event_id for this state key. None if the
+ state has been deleted.
+ - prev_event_id (str|None): previous event_id for this state key. None
+ if it's new state.
+
+ Args:
+ prev_stream_id (int): point to get changes since (exclusive)
+
+ Returns:
+ Deferred[list[dict]]: results
+ """
prev_stream_id = int(prev_stream_id)
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
|