diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 88 |
1 files changed, 85 insertions, 3 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 983ce026e1..52891bb9eb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +16,7 @@ # limitations under the License. import itertools import logging +import random import sys import threading import time @@ -227,6 +230,8 @@ class SQLBaseStore(object): # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + self._account_validity = self.hs.config.account_validity + # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. @@ -243,6 +248,16 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + self.rand = random.SystemRandom() + + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ @@ -275,6 +290,67 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL;" + ) + txn.execute(sql, []) + + res = self.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, + user["name"], + use_delta=True, + ) + + yield self.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self._simple_insert_txn( + txn, + "account_validity", + values={ + "user_id": user_id, + "expiration_ts_ms": expiration_ts, + "email_sent": False, + }, + ) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() @@ -1203,7 +1279,8 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues), ) - return txn.execute(sql, list(keyvalues.values())) + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount def _simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( @@ -1222,9 +1299,12 @@ class SQLBaseStore(object): column : column name to test for inclusion against `iterable` iterable : list keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted """ if not iterable: - return + return 0 sql = "DELETE FROM %s" % table @@ -1239,7 +1319,9 @@ class SQLBaseStore(object): if clauses: sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - return txn.execute(sql, values) + txn.execute(sql, values) + + return txn.rowcount def _get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 |