diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 9b6c28892c..03a06a83d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -32,18 +32,21 @@ class RegistrationWorkerStore(SQLBaseStore):
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config
+ self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
- keyvalues={
- "name": user_id,
- },
+ keyvalues={"name": user_id},
retcols=[
- "name", "password_hash", "is_guest",
- "consent_version", "consent_server_notice_sent",
- "appservice_id", "creation_ts",
+ "name",
+ "password_hash",
+ "is_guest",
+ "consent_version",
+ "consent_server_notice_sent",
+ "appservice_id",
+ "creation_ts",
],
allow_none=True,
desc="get_user_by_id",
@@ -81,9 +84,163 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
- "get_user_by_access_token",
- self._query_for_auth,
- token
+ "get_user_by_access_token", self._query_for_auth, token
+ )
+
+ @cachedInlineCallbacks()
+ def get_expiration_ts_for_user(self, user_id):
+ """Get the expiration timestamp for the account bearing a given user ID.
+
+ Args:
+ user_id (str): The ID of the user.
+ Returns:
+ defer.Deferred: None, if the account has no expiration timestamp,
+ otherwise int representation of the timestamp (as a number of
+ milliseconds since epoch).
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="expiration_ts_ms",
+ allow_none=True,
+ desc="get_expiration_ts_for_user",
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
+ renewal_token=None):
+ """Updates the account validity properties of the given account, with the
+ given values.
+
+ Args:
+ user_id (str): ID of the account to update properties for.
+ expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ since epoch.
+ email_sent (bool): True means a renewal email has been sent for this
+ account and there's no need to send another one for the current validity
+ period.
+ renewal_token (str): Renewal token the user can use to extend the validity
+ of their account. Defaults to no token.
+ """
+ def set_account_validity_for_user_txn(txn):
+ self._simple_update_txn(
+ txn=txn,
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": email_sent,
+ "renewal_token": renewal_token,
+ },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_expiration_ts_for_user, (user_id,),
+ )
+
+ yield self.runInteraction(
+ "set_account_validity_for_user",
+ set_account_validity_for_user_txn,
+ )
+
+ @defer.inlineCallbacks
+ def set_renewal_token_for_user(self, user_id, renewal_token):
+ """Defines a renewal token for a given user.
+
+ Args:
+ user_id (str): ID of the user to set the renewal token for.
+ renewal_token (str): Random unique string that will be used to renew the
+ user's account.
+
+ Raises:
+ StoreError: The provided token is already set for another user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"renewal_token": renewal_token},
+ desc="set_renewal_token_for_user",
+ )
+
+ @defer.inlineCallbacks
+ def get_user_from_renewal_token(self, renewal_token):
+ """Get a user ID from a renewal token.
+
+ Args:
+ renewal_token (str): The renewal token to perform the lookup with.
+
+ Returns:
+ defer.Deferred[str]: The ID of the user to which the token belongs.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"renewal_token": renewal_token},
+ retcol="user_id",
+ desc="get_user_from_renewal_token",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_renewal_token_for_user(self, user_id):
+ """Get the renewal token associated with a given user ID.
+
+ Args:
+ user_id (str): The user ID to lookup a token for.
+
+ Returns:
+ defer.Deferred[str]: The renewal token associated with this user ID.
+ """
+ res = yield self._simple_select_one_onecol(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ retcol="renewal_token",
+ desc="get_renewal_token_for_user",
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def get_users_expiring_soon(self):
+ """Selects users whose account will expire in the [now, now + renew_at] time
+ window (see configuration for account_validity for information on what renew_at
+ refers to).
+
+ Returns:
+ Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ """
+ def select_users_txn(txn, now_ms, renew_at):
+ sql = (
+ "SELECT user_id, expiration_ts_ms FROM account_validity"
+ " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
+ )
+ values = [False, now_ms, renew_at]
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
+ res = yield self.runInteraction(
+ "get_users_expiring_soon",
+ select_users_txn,
+ self.clock.time_msec(), self.config.account_validity.renew_at,
+ )
+
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def set_renewal_mail_status(self, user_id, email_sent):
+ """Sets or unsets the flag that indicates whether a renewal email has been sent
+ to the user (and the user hasn't renewed their account yet).
+
+ Args:
+ user_id (str): ID of the user to set/unset the flag for.
+ email_sent (bool): Flag which indicates whether a renewal email has been sent
+ to this user.
+ """
+ yield self._simple_update_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ updatevalues={"email_sent": email_sent},
+ desc="set_renewal_mail_status",
)
@defer.inlineCallbacks
@@ -143,10 +300,10 @@ class RegistrationWorkerStore(SQLBaseStore):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
+
def f(txn):
sql = (
- "SELECT name, password_hash FROM users"
- " WHERE lower(name) = lower(?)"
+ "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
@@ -156,6 +313,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
+
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
@@ -173,6 +331,7 @@ class RegistrationWorkerStore(SQLBaseStore):
3) bridged users
who registered on the homeserver in the past 24 hours
"""
+
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
@@ -193,15 +352,18 @@ class RegistrationWorkerStore(SQLBaseStore):
for row in txn:
results[row[0]] = row[1]
return results
+
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
- txn.execute("""
+ txn.execute(
+ """
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
- """)
+ """
+ )
count, = txn.fetchone()
return count
@@ -220,6 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore):
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
+
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
@@ -227,7 +390,7 @@ class RegistrationWorkerStore(SQLBaseStore):
found = set()
- for user_id, in txn:
+ for (user_id,) in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
@@ -235,20 +398,22 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found:
return i
- defer.returnValue((yield self.runInteraction(
- "find_next_generated_user_id",
- _find_next_generated_user_id
- )))
+ defer.returnValue(
+ (
+ yield self.runInteraction(
+ "find_next_generated_user_id", _find_next_generated_user_id
+ )
+ )
+ )
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
- {
- "medium": medium,
- "address": address
- },
- ["guest_access_token"], True, 'get_3pid_guest_access_token'
+ {"medium": medium, "address": address},
+ ["guest_access_token"],
+ True,
+ 'get_3pid_guest_access_token',
)
if ret:
defer.returnValue(ret["guest_access_token"])
@@ -266,8 +431,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
- "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
- medium, address
+ "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
defer.returnValue(user_id)
@@ -285,11 +449,9 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = self._simple_select_one_txn(
txn,
"user_threepids",
- {
- "medium": medium,
- "address": address
- },
- ['user_id'], True
+ {"medium": medium, "address": address},
+ ['user_id'],
+ True,
)
if ret:
return ret['user_id']
@@ -297,41 +459,110 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert("user_threepids", {
- "medium": medium,
- "address": address,
- }, {
- "user_id": user_id,
- "validated_at": validated_at,
- "added_at": added_at,
- })
+ yield self._simple_upsert(
+ "user_threepids",
+ {"medium": medium, "address": address},
+ {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
+ )
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
- "user_threepids", {
- "user_id": user_id
- },
+ "user_threepids",
+ {"user_id": user_id},
['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids'
+ 'user_get_threepids',
)
defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepids",
+ )
+
+ def add_user_bound_threepid(self, user_id, medium, address, id_server):
+ """The server proxied a bind request to the given identity server on
+ behalf of the given user. We need to remember this in case the user
+ asks us to unbind the threepid.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+ id_server (str)
+
+ Returns:
+ Deferred
+ """
+ # We need to use an upsert, in case they user had already bound the
+ # threepid
+ return self._simple_upsert(
+ table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
+ "id_server": id_server,
},
- desc="user_delete_threepids",
+ values={},
+ insertion_values={},
+ desc="add_user_bound_threepid",
+ )
+
+ def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+ """The server proxied an unbind request to the given identity server on
+ behalf of the given user, so we remove the mapping of threepid to
+ identity server.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+ id_server (str)
+
+ Returns:
+ Deferred
+ """
+ return self._simple_delete(
+ table="user_threepid_id_server",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ "id_server": id_server,
+ },
+ desc="remove_user_bound_threepid",
)
+ def get_id_servers_user_bound(self, user_id, medium, address):
+ """Get the list of identity servers that the server proxied bind
+ requests to for given user and threepid
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+
+ Returns:
+ Deferred[list[str]]: Resolves to a list of identity servers
+ """
+ return self._simple_select_onecol(
+ table="user_threepid_id_server",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ },
+ retcol="id_server",
+ desc="get_id_servers_user_bound",
+ )
-class RegistrationStore(RegistrationWorkerStore,
- background_updates.BackgroundUpdateStore):
+class RegistrationStore(
+ RegistrationWorkerStore, background_updates.BackgroundUpdateStore
+):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
@@ -351,11 +582,17 @@ class RegistrationStore(RegistrationWorkerStore,
columns=["creation_ts"],
)
+ self._account_validity = hs.config.account_validity
+
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
self.register_noop_background_update("refresh_tokens_device_index")
+ self.register_background_update_handler(
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ )
+
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -372,18 +609,22 @@ class RegistrationStore(RegistrationWorkerStore,
yield self._simple_insert(
"access_tokens",
- {
- "id": next_id,
- "user_id": user_id,
- "token": token,
- "device_id": device_id,
- },
+ {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
desc="add_access_token_to_user",
)
- def register(self, user_id, token=None, password_hash=None,
- was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_displayname=None, admin=False, user_type=None):
+ def register(
+ self,
+ user_id,
+ token=None,
+ password_hash=None,
+ was_guest=False,
+ make_guest=False,
+ appservice_id=None,
+ create_profile_with_displayname=None,
+ admin=False,
+ user_type=None,
+ ):
"""Attempts to register an account.
Args:
@@ -417,7 +658,7 @@ class RegistrationStore(RegistrationWorkerStore,
appservice_id,
create_profile_with_displayname,
admin,
- user_type
+ user_type,
)
def _register(
@@ -447,10 +688,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_select_one_txn(
txn,
"users",
- keyvalues={
- "name": user_id,
- "is_guest": 1,
- },
+ keyvalues={"name": user_id, "is_guest": 1},
retcols=("name",),
allow_none=False,
)
@@ -458,10 +696,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_update_one_txn(
txn,
"users",
- keyvalues={
- "name": user_id,
- "is_guest": 1,
- },
+ keyvalues={"name": user_id, "is_guest": 1},
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
@@ -469,7 +704,7 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
- }
+ },
)
else:
self._simple_insert_txn(
@@ -483,20 +718,31 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
- }
+ },
)
+
except self.database_engine.module.IntegrityError:
- raise StoreError(
- 400, "User ID already taken.", errcode=Codes.USER_IN_USE
+ raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
+
+ if self._account_validity.enabled:
+ now_ms = self.clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ }
)
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute(
- "INSERT INTO access_tokens(id, user_id, token)"
- " VALUES (?,?,?)",
- (next_id, user_id, token,)
+ "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
+ (next_id, user_id, token),
)
if create_profile_with_displayname:
@@ -507,12 +753,10 @@ class RegistrationStore(RegistrationWorkerStore,
# while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
- (user_id_obj.localpart, create_profile_with_displayname)
+ (user_id_obj.localpart, create_profile_with_displayname),
)
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
- )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def user_set_password_hash(self, user_id, password_hash):
@@ -521,22 +765,14 @@ class RegistrationStore(RegistrationWorkerStore,
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
+
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
- txn,
- 'users', {
- 'name': user_id
- },
- {
- 'password_hash': password_hash
- }
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ txn, 'users', {'name': user_id}, {'password_hash': password_hash}
)
- return self.runInteraction(
- "user_set_password_hash", user_set_password_hash_txn
- )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+ return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@@ -549,16 +785,16 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
+
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
- keyvalues={'name': user_id, },
- updatevalues={'consent_version': consent_version, },
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ keyvalues={'name': user_id},
+ updatevalues={'consent_version': consent_version},
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
return self.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
@@ -573,20 +809,19 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
+
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
- keyvalues={'name': user_id, },
- updatevalues={'consent_server_notice_sent': consent_version, },
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ keyvalues={'name': user_id},
+ updatevalues={'consent_server_notice_sent': consent_version},
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
return self.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None,
- device_id=None):
+ def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
Invalidate access tokens belonging to a user
@@ -601,10 +836,9 @@ class RegistrationStore(RegistrationWorkerStore,
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
"""
+
def f(txn):
- keyvalues = {
- "user_id": user_id,
- }
+ keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@@ -616,8 +850,9 @@ class RegistrationStore(RegistrationWorkerStore,
values.append(except_token_id)
txn.execute(
- "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
- values
+ "SELECT token, id, device_id FROM access_tokens WHERE %s"
+ % where_clause,
+ values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
@@ -626,25 +861,16 @@ class RegistrationStore(RegistrationWorkerStore,
txn, self.get_user_by_access_token, (token,)
)
- txn.execute(
- "DELETE FROM access_tokens WHERE %s" % where_clause,
- values
- )
+ txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
return tokens_and_devices
- return self.runInteraction(
- "user_delete_access_tokens", f,
- )
+ return self.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
self._simple_delete_one_txn(
- txn,
- table="access_tokens",
- keyvalues={
- "token": access_token
- },
+ txn, table="access_tokens", keyvalues={"token": access_token}
)
self._invalidate_cache_and_stream(
@@ -667,7 +893,7 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
- self, medium, address, access_token, inviter_user_id
+ self, medium, address, access_token, inviter_user_id
):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
@@ -683,12 +909,13 @@ class RegistrationStore(RegistrationWorkerStore,
deferred str: Whichever access token is persisted at the end
of this function call.
"""
+
def insert(txn):
txn.execute(
"INSERT INTO threepid_guest_access_tokens "
"(medium, address, guest_access_token, first_inviter) "
"VALUES (?, ?, ?, ?)",
- (medium, address, access_token, inviter_user_id)
+ (medium, address, access_token, inviter_user_id),
)
try:
@@ -705,9 +932,7 @@ class RegistrationStore(RegistrationWorkerStore,
"""
return self._simple_insert(
"users_pending_deactivation",
- values={
- "user_id": user_id,
- },
+ values={"user_id": user_id},
desc="add_user_pending_deactivation",
)
@@ -720,9 +945,7 @@ class RegistrationStore(RegistrationWorkerStore,
# the table, so somehow duplicate entries have ended up in it.
return self._simple_delete(
"users_pending_deactivation",
- keyvalues={
- "user_id": user_id,
- },
+ keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
)
@@ -738,3 +961,34 @@ class RegistrationStore(RegistrationWorkerStore,
allow_none=True,
desc="get_users_pending_deactivation",
)
+
+ @defer.inlineCallbacks
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server configured trusted identity servers.
+ """
+
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+ )
+
+ yield self._end_background_update("user_threepids_grandfather")
+
+ defer.returnValue(1)
|