diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..236d3cdbe3 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
@@ -48,6 +47,18 @@ class RegistrationWorkerStore(SQLBaseStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
+ self._account_validity = hs.config.account_validity
+ if hs.config.run_background_tasks and self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ if hs.config.run_background_tasks:
+ self.clock.looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+ )
+
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
@@ -778,6 +789,79 @@ class RegistrationWorkerStore(SQLBaseStore):
"delete_threepid_session", delete_threepid_session_txn
)
+ @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+ async def cull_expired_threepid_validation_tokens(self) -> None:
+ """Remove threepid validation tokens with expiry dates that have passed"""
+
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self.clock.time_msec(),
+ )
+
+ @wrap_as_background_process("account_validity_set_expiration_dates")
+ async def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db_pool.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ await self.db_pool.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -911,28 +995,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
-
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
async def add_access_token_to_user(
self,
user_id: str,
@@ -964,6 +1028,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
+ def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn, "access_tokens", {"token": token}, "device_id"
+ )
+
+ self.db_pool.simple_update_txn(
+ txn, "access_tokens", {"token": token}, {"device_id": device_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+ return old_device_id
+
+ async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+ """Sets the device ID associated with an access token.
+
+ Args:
+ token: The access token to modify.
+ device_id: The new device ID.
+ Returns:
+ The old device ID associated with the access token.
+ """
+
+ return await self.db_pool.runInteraction(
+ "set_device_for_access_token",
+ self._set_device_for_access_token_txn,
+ token,
+ device_id,
+ )
+
async def register_user(
self,
user_id: str,
@@ -1121,7 +1215,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+ async def user_set_password_hash(
+ self, user_id: str, password_hash: Optional[str]
+ ) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1447,22 +1543,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
- )
-
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
@@ -1492,61 +1572,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- async def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- await self.db_pool.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self.db_pool.simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id},
- values={"expiration_ts_ms": expiration_ts, "email_sent": False},
- )
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
|