diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..fedb8a6c26 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,32 +14,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import DatabasePool
-from synapse.storage.types import Cursor
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__)
-class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+ """Result of looking up an access token.
+
+ Attributes:
+ user_id: The user that this token authenticates as
+ is_guest
+ shadow_banned
+ token_id: The ID of the access token looked up
+ device_id: The device associated with the token, if any.
+ valid_until_ms: The timestamp the token expires, if any.
+ token_owner: The "owner" of the token. This is either the same as the
+ user, or a server admin who is logged in as the user.
+ """
+
+ user_id = attr.ib(type=str)
+ is_guest = attr.ib(type=bool, default=False)
+ shadow_banned = attr.ib(type=bool, default=False)
+ token_id = attr.ib(type=Optional[int], default=None)
+ device_id = attr.ib(type=Optional[str], default=None)
+ valid_until_ms = attr.ib(type=Optional[int], default=None)
+ token_owner = attr.ib(type=str)
+
+ # Make the token owner default to the user ID, which is the common case.
+ @token_owner.default
+ def _default_token_owner(self):
+ return self.user_id
+
+
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.config = hs.config
- self.clock = hs.get_clock()
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
@@ -48,6 +82,18 @@ class RegistrationWorkerStore(SQLBaseStore):
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
+ self._account_validity = hs.config.account_validity
+ if hs.config.run_background_tasks and self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0, self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ if hs.config.run_background_tasks:
+ self._clock.looping_call(
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+ )
+
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
@@ -81,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore):
if not info:
return False
- now = self.clock.time_msec()
+ now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
@cached()
- async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+ async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token.
Args:
token: The access token of a user.
Returns:
- None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`,
- `valid_until_ms`.
+ None, if the token did not match, otherwise a `TokenLookupResult`
"""
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
@@ -225,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_renewal_token_for_user",
)
- async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
+ async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- A list of dictionaries mapping user ID to expiration time (in milliseconds).
+ A list of dictionaries, each with a user ID and expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@@ -246,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(),
+ self._clock.time_msec(),
self.config.account_validity.renew_at,
)
@@ -316,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
- def _query_for_auth(self, txn, token):
- sql = (
- "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
- " access_tokens.device_id, access_tokens.valid_until_ms"
- " FROM users"
- " INNER JOIN access_tokens on users.name = access_tokens.user_id"
- " WHERE token = ?"
- )
+ def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+ sql = """
+ SELECT users.name as user_id,
+ users.is_guest,
+ users.shadow_banned,
+ access_tokens.id as token_id,
+ access_tokens.device_id,
+ access_tokens.valid_until_ms,
+ access_tokens.user_id as token_owner
+ FROM users
+ INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
+ WHERE token = ?
+ """
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
if rows:
- return rows[0]
+ return TokenLookupResult(**rows[0])
return None
@@ -778,12 +827,111 @@ class RegistrationWorkerStore(SQLBaseStore):
"delete_threepid_session", delete_threepid_session_txn
)
+ @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+ async def cull_expired_threepid_validation_tokens(self) -> None:
+ """Remove threepid validation tokens with expiry dates that have passed"""
+
+ def cull_expired_threepid_validation_tokens_txn(txn, ts):
+ sql = """
+ DELETE FROM threepid_validation_token WHERE
+ expires < ?
+ """
+ txn.execute(sql, (ts,))
+
+ await self.db_pool.runInteraction(
+ "cull_expired_threepid_validation_tokens",
+ cull_expired_threepid_validation_tokens_txn,
+ self._clock.time_msec(),
+ )
+
+ @wrap_as_background_process("account_validity_set_expiration_dates")
+ async def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db_pool.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ await self.db_pool.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.db_pool.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
+
+ async def get_user_pending_deactivation(self) -> Optional[str]:
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
+
+ async def del_user_pending_deactivation(self, user_id: str) -> None:
+ """
+ Removes the given user to the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ # XXX: This should be simple_delete_one but we failed to put a unique index on
+ # the table, so somehow duplicate entries have ended up in it.
+ await self.db_pool.simple_delete(
+ "users_pending_deactivation",
+ keyvalues={"user_id": user_id},
+ desc="del_user_pending_deactivation",
+ )
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
- self.clock = hs.get_clock()
+ self._clock = hs.get_clock()
self.config = hs.config
self.db_pool.updates.register_background_index_update(
@@ -906,32 +1054,55 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return 1
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
+ """Set the `deactivated` property for the provided user to the provided value.
-class RegistrationStore(RegistrationBackgroundUpdateStore):
- def __init__(self, database: DatabasePool, db_conn, hs):
- super().__init__(database, db_conn, hs)
+ Args:
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
+ """
- self._account_validity = hs.config.account_validity
- self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+ await self.db_pool.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
+ def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
+ )
+ txn.call_after(self.is_guest.invalidate, (user_id,))
+
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
- # Create a background job for culling expired 3PID validity tokens
- def start_cull():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "cull_expired_threepid_validation_tokens",
- self.cull_expired_threepid_validation_tokens,
- )
- hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
+
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
async def add_access_token_to_user(
self,
@@ -939,7 +1110,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
- ) -> None:
+ puppets_user_id: Optional[str] = None,
+ ) -> int:
"""Adds an access token for the given user.
Args:
@@ -949,6 +1121,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
+ Returns:
+ The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
@@ -960,10 +1134,43 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
+ "puppets_user_id": puppets_user_id,
},
desc="add_access_token_to_user",
)
+ return next_id
+
+ def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+ old_device_id = self.db_pool.simple_select_one_onecol_txn(
+ txn, "access_tokens", {"token": token}, "device_id"
+ )
+
+ self.db_pool.simple_update_txn(
+ txn, "access_tokens", {"token": token}, {"device_id": device_id}
+ )
+
+ self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+ return old_device_id
+
+ async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+ """Sets the device ID associated with an access token.
+
+ Args:
+ token: The access token to modify.
+ device_id: The new device ID.
+ Returns:
+ The old device ID associated with the access token.
+ """
+
+ return await self.db_pool.runInteraction(
+ "set_device_for_access_token",
+ self._set_device_for_access_token_txn,
+ token,
+ device_id,
+ )
+
async def register_user(
self,
user_id: str,
@@ -1014,19 +1221,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def _register_user(
self,
txn,
- user_id,
- password_hash,
- was_guest,
- make_guest,
- appservice_id,
- create_profile_with_displayname,
- admin,
- user_type,
- shadow_banned,
+ user_id: str,
+ password_hash: Optional[str],
+ was_guest: bool,
+ make_guest: bool,
+ appservice_id: Optional[str],
+ create_profile_with_displayname: Optional[str],
+ admin: bool,
+ user_type: Optional[str],
+ shadow_banned: bool,
):
user_id_obj = UserID.from_string(user_id)
- now = int(self.clock.time())
+ now = int(self._clock.time())
try:
if was_guest:
@@ -1121,7 +1328,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
- async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+ async def user_set_password_hash(
+ self, user_id: str, password_hash: Optional[str]
+ ) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@@ -1248,18 +1457,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1271,32 +1468,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_user_pending_deactivation",
)
- async def del_user_pending_deactivation(self, user_id: str) -> None:
- """
- Removes the given user to the table of users who need to be parted from all the
- rooms they're in, effectively marking that user as fully deactivated.
- """
- # XXX: This should be simple_delete_one but we failed to put a unique index on
- # the table, so somehow duplicate entries have ended up in it.
- await self.db_pool.simple_delete(
- "users_pending_deactivation",
- keyvalues={"user_id": user_id},
- desc="del_user_pending_deactivation",
- )
-
- async def get_user_pending_deactivation(self) -> Optional[str]:
- """
- Gets one user from the table of users waiting to be parted from all the rooms
- they're in.
- """
- return await self.db_pool.simple_select_one_onecol(
- "users_pending_deactivation",
- keyvalues={},
- retcol="user_id",
- allow_none=True,
- desc="get_users_pending_deactivation",
- )
-
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
@@ -1379,7 +1550,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
- updatevalues={"validated_at": self.clock.time_msec()},
+ updatevalues={"validated_at": self._clock.time_msec()},
)
return next_link
@@ -1447,106 +1618,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
- async def cull_expired_threepid_validation_tokens(self) -> None:
- """Remove threepid validation tokens with expiry dates that have passed"""
-
- def cull_expired_threepid_validation_tokens_txn(txn, ts):
- sql = """
- DELETE FROM threepid_validation_token WHERE
- expires < ?
- """
- txn.execute(sql, (ts,))
-
- await self.db_pool.runInteraction(
- "cull_expired_threepid_validation_tokens",
- cull_expired_threepid_validation_tokens_txn,
- self.clock.time_msec(),
- )
-
- async def set_user_deactivated_status(
- self, user_id: str, deactivated: bool
- ) -> None:
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id: The ID of the user to set the status for.
- deactivated: The value to set for `deactivated`.
- """
-
- await self.db_pool.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id,
- deactivated,
- )
-
- def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self.db_pool.simple_update_one_txn(
- txn=txn,
- table="users",
- keyvalues={"name": user_id},
- updatevalues={"deactivated": 1 if deactivated else 0},
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,)
- )
- txn.call_after(self.is_guest.invalidate, (user_id,))
-
- async def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.db_pool.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- await self.db_pool.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self.db_pool.simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id},
- values={"expiration_ts_ms": expiration_ts, "email_sent": False},
- )
-
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
|