diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 983ce026e1..941c07fce5 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +16,7 @@
# limitations under the License.
import itertools
import logging
+import random
import sys
import threading
import time
@@ -227,6 +230,8 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+ self._account_validity = self.hs.config.account_validity
+
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@@ -243,6 +248,16 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ self.rand = random.SystemRandom()
+
+ if self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@@ -275,6 +290,67 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ @defer.inlineCallbacks
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn,
+ user["name"],
+ use_delta=True,
+ )
+
+ yield self.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ },
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -512,6 +588,10 @@ class SQLBaseStore(object):
Args:
table : string giving the table name
values : dict of new column names and values for them
+ or_ignore : bool stating whether an exception should be raised
+ when a conflicting row already exists. If True, False will be
+ returned by the function instead
+ desc : string giving a description of the transaction
Returns:
bool: Whether the row was inserted or not. Only useful when
@@ -1152,8 +1232,8 @@ class SQLBaseStore(object):
)
txn.execute(select_sql, list(keyvalues.values()))
-
row = txn.fetchone()
+
if not row:
if allow_none:
return None
@@ -1203,7 +1283,8 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
- return txn.execute(sql, list(keyvalues.values()))
+ txn.execute(sql, list(keyvalues.values()))
+ return txn.rowcount
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -1222,9 +1303,12 @@ class SQLBaseStore(object):
column : column name to test for inclusion against `iterable`
iterable : list
keyvalues : dict of column names and values to select the rows with
+
+ Returns:
+ int: Number rows deleted
"""
if not iterable:
- return
+ return 0
sql = "DELETE FROM %s" % table
@@ -1239,7 +1323,9 @@ class SQLBaseStore(object):
if clauses:
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
- return txn.execute(sql, values)
+ txn.execute(sql, values)
+
+ return txn.rowcount
def _get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
|