diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index dfdb4e7e34..03a06a83d6 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -38,13 +38,15 @@ class RegistrationWorkerStore(SQLBaseStore):
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
- keyvalues={
- "name": user_id,
- },
+ keyvalues={"name": user_id},
retcols=[
- "name", "password_hash", "is_guest",
- "consent_version", "consent_server_notice_sent",
- "appservice_id", "creation_ts",
+ "name",
+ "password_hash",
+ "is_guest",
+ "consent_version",
+ "consent_server_notice_sent",
+ "appservice_id",
+ "creation_ts",
],
allow_none=True,
desc="get_user_by_id",
@@ -82,9 +84,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
- "get_user_by_access_token",
- self._query_for_auth,
- token
+ "get_user_by_access_token", self._query_for_auth, token
)
@cachedInlineCallbacks()
@@ -300,10 +300,10 @@ class RegistrationWorkerStore(SQLBaseStore):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
+
def f(txn):
sql = (
- "SELECT name, password_hash FROM users"
- " WHERE lower(name) = lower(?)"
+ "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
@@ -313,6 +313,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
+
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
@@ -330,6 +331,7 @@ class RegistrationWorkerStore(SQLBaseStore):
3) bridged users
who registered on the homeserver in the past 24 hours
"""
+
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
@@ -350,15 +352,18 @@ class RegistrationWorkerStore(SQLBaseStore):
for row in txn:
results[row[0]] = row[1]
return results
+
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
- txn.execute("""
+ txn.execute(
+ """
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
- """)
+ """
+ )
count, = txn.fetchone()
return count
@@ -377,6 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore):
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
+
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
@@ -384,7 +390,7 @@ class RegistrationWorkerStore(SQLBaseStore):
found = set()
- for user_id, in txn:
+ for (user_id,) in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
@@ -392,20 +398,22 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found:
return i
- defer.returnValue((yield self.runInteraction(
- "find_next_generated_user_id",
- _find_next_generated_user_id
- )))
+ defer.returnValue(
+ (
+ yield self.runInteraction(
+ "find_next_generated_user_id", _find_next_generated_user_id
+ )
+ )
+ )
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
- {
- "medium": medium,
- "address": address
- },
- ["guest_access_token"], True, 'get_3pid_guest_access_token'
+ {"medium": medium, "address": address},
+ ["guest_access_token"],
+ True,
+ 'get_3pid_guest_access_token',
)
if ret:
defer.returnValue(ret["guest_access_token"])
@@ -423,8 +431,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
- "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
- medium, address
+ "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
defer.returnValue(user_id)
@@ -442,11 +449,9 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = self._simple_select_one_txn(
txn,
"user_threepids",
- {
- "medium": medium,
- "address": address
- },
- ['user_id'], True
+ {"medium": medium, "address": address},
+ ['user_id'],
+ True,
)
if ret:
return ret['user_id']
@@ -454,41 +459,110 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert("user_threepids", {
- "medium": medium,
- "address": address,
- }, {
- "user_id": user_id,
- "validated_at": validated_at,
- "added_at": added_at,
- })
+ yield self._simple_upsert(
+ "user_threepids",
+ {"medium": medium, "address": address},
+ {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
+ )
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
- "user_threepids", {
- "user_id": user_id
- },
+ "user_threepids",
+ {"user_id": user_id},
['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids'
+ 'user_get_threepids',
)
defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepids",
+ )
+
+ def add_user_bound_threepid(self, user_id, medium, address, id_server):
+ """The server proxied a bind request to the given identity server on
+ behalf of the given user. We need to remember this in case the user
+ asks us to unbind the threepid.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+ id_server (str)
+
+ Returns:
+ Deferred
+ """
+ # We need to use an upsert, in case they user had already bound the
+ # threepid
+ return self._simple_upsert(
+ table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
+ "id_server": id_server,
},
- desc="user_delete_threepids",
+ values={},
+ insertion_values={},
+ desc="add_user_bound_threepid",
)
+ def remove_user_bound_threepid(self, user_id, medium, address, id_server):
+ """The server proxied an unbind request to the given identity server on
+ behalf of the given user, so we remove the mapping of threepid to
+ identity server.
+
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+ id_server (str)
+
+ Returns:
+ Deferred
+ """
+ return self._simple_delete(
+ table="user_threepid_id_server",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ "id_server": id_server,
+ },
+ desc="remove_user_bound_threepid",
+ )
-class RegistrationStore(RegistrationWorkerStore,
- background_updates.BackgroundUpdateStore):
+ def get_id_servers_user_bound(self, user_id, medium, address):
+ """Get the list of identity servers that the server proxied bind
+ requests to for given user and threepid
+ Args:
+ user_id (str)
+ medium (str)
+ address (str)
+
+ Returns:
+ Deferred[list[str]]: Resolves to a list of identity servers
+ """
+ return self._simple_select_onecol(
+ table="user_threepid_id_server",
+ keyvalues={
+ "user_id": user_id,
+ "medium": medium,
+ "address": address,
+ },
+ retcol="id_server",
+ desc="get_id_servers_user_bound",
+ )
+
+
+class RegistrationStore(
+ RegistrationWorkerStore, background_updates.BackgroundUpdateStore
+):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
@@ -515,6 +589,10 @@ class RegistrationStore(RegistrationWorkerStore,
# clear the background update.
self.register_noop_background_update("refresh_tokens_device_index")
+ self.register_background_update_handler(
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ )
+
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -531,18 +609,22 @@ class RegistrationStore(RegistrationWorkerStore,
yield self._simple_insert(
"access_tokens",
- {
- "id": next_id,
- "user_id": user_id,
- "token": token,
- "device_id": device_id,
- },
+ {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
desc="add_access_token_to_user",
)
- def register(self, user_id, token=None, password_hash=None,
- was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_displayname=None, admin=False, user_type=None):
+ def register(
+ self,
+ user_id,
+ token=None,
+ password_hash=None,
+ was_guest=False,
+ make_guest=False,
+ appservice_id=None,
+ create_profile_with_displayname=None,
+ admin=False,
+ user_type=None,
+ ):
"""Attempts to register an account.
Args:
@@ -576,7 +658,7 @@ class RegistrationStore(RegistrationWorkerStore,
appservice_id,
create_profile_with_displayname,
admin,
- user_type
+ user_type,
)
def _register(
@@ -606,10 +688,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_select_one_txn(
txn,
"users",
- keyvalues={
- "name": user_id,
- "is_guest": 1,
- },
+ keyvalues={"name": user_id, "is_guest": 1},
retcols=("name",),
allow_none=False,
)
@@ -617,10 +696,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_update_one_txn(
txn,
"users",
- keyvalues={
- "name": user_id,
- "is_guest": 1,
- },
+ keyvalues={"name": user_id, "is_guest": 1},
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
@@ -628,7 +704,7 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
- }
+ },
)
else:
self._simple_insert_txn(
@@ -642,13 +718,11 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
- }
+ },
)
except self.database_engine.module.IntegrityError:
- raise StoreError(
- 400, "User ID already taken.", errcode=Codes.USER_IN_USE
- )
+ raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
now_ms = self.clock.time_msec()
@@ -667,9 +741,8 @@ class RegistrationStore(RegistrationWorkerStore,
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute(
- "INSERT INTO access_tokens(id, user_id, token)"
- " VALUES (?,?,?)",
- (next_id, user_id, token,)
+ "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
+ (next_id, user_id, token),
)
if create_profile_with_displayname:
@@ -680,12 +753,10 @@ class RegistrationStore(RegistrationWorkerStore,
# while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
- (user_id_obj.localpart, create_profile_with_displayname)
+ (user_id_obj.localpart, create_profile_with_displayname),
)
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
- )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def user_set_password_hash(self, user_id, password_hash):
@@ -694,22 +765,14 @@ class RegistrationStore(RegistrationWorkerStore,
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
+
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
- txn,
- 'users', {
- 'name': user_id
- },
- {
- 'password_hash': password_hash
- }
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ txn, 'users', {'name': user_id}, {'password_hash': password_hash}
)
- return self.runInteraction(
- "user_set_password_hash", user_set_password_hash_txn
- )
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
+ return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@@ -722,16 +785,16 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
+
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
- keyvalues={'name': user_id, },
- updatevalues={'consent_version': consent_version, },
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ keyvalues={'name': user_id},
+ updatevalues={'consent_version': consent_version},
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
return self.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
@@ -746,20 +809,19 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
+
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
- keyvalues={'name': user_id, },
- updatevalues={'consent_server_notice_sent': consent_version, },
- )
- self._invalidate_cache_and_stream(
- txn, self.get_user_by_id, (user_id,)
+ keyvalues={'name': user_id},
+ updatevalues={'consent_server_notice_sent': consent_version},
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
return self.runInteraction("user_set_consent_server_notice_sent", f)
- def user_delete_access_tokens(self, user_id, except_token_id=None,
- device_id=None):
+ def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
Invalidate access tokens belonging to a user
@@ -774,10 +836,9 @@ class RegistrationStore(RegistrationWorkerStore,
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
"""
+
def f(txn):
- keyvalues = {
- "user_id": user_id,
- }
+ keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@@ -789,8 +850,9 @@ class RegistrationStore(RegistrationWorkerStore,
values.append(except_token_id)
txn.execute(
- "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
- values
+ "SELECT token, id, device_id FROM access_tokens WHERE %s"
+ % where_clause,
+ values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
@@ -799,25 +861,16 @@ class RegistrationStore(RegistrationWorkerStore,
txn, self.get_user_by_access_token, (token,)
)
- txn.execute(
- "DELETE FROM access_tokens WHERE %s" % where_clause,
- values
- )
+ txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
return tokens_and_devices
- return self.runInteraction(
- "user_delete_access_tokens", f,
- )
+ return self.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
self._simple_delete_one_txn(
- txn,
- table="access_tokens",
- keyvalues={
- "token": access_token
- },
+ txn, table="access_tokens", keyvalues={"token": access_token}
)
self._invalidate_cache_and_stream(
@@ -840,7 +893,7 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
- self, medium, address, access_token, inviter_user_id
+ self, medium, address, access_token, inviter_user_id
):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
@@ -856,12 +909,13 @@ class RegistrationStore(RegistrationWorkerStore,
deferred str: Whichever access token is persisted at the end
of this function call.
"""
+
def insert(txn):
txn.execute(
"INSERT INTO threepid_guest_access_tokens "
"(medium, address, guest_access_token, first_inviter) "
"VALUES (?, ?, ?, ?)",
- (medium, address, access_token, inviter_user_id)
+ (medium, address, access_token, inviter_user_id),
)
try:
@@ -878,9 +932,7 @@ class RegistrationStore(RegistrationWorkerStore,
"""
return self._simple_insert(
"users_pending_deactivation",
- values={
- "user_id": user_id,
- },
+ values={"user_id": user_id},
desc="add_user_pending_deactivation",
)
@@ -893,9 +945,7 @@ class RegistrationStore(RegistrationWorkerStore,
# the table, so somehow duplicate entries have ended up in it.
return self._simple_delete(
"users_pending_deactivation",
- keyvalues={
- "user_id": user_id,
- },
+ keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
)
@@ -911,3 +961,34 @@ class RegistrationStore(RegistrationWorkerStore,
allow_none=True,
desc="get_users_pending_deactivation",
)
+
+ @defer.inlineCallbacks
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server configured trusted identity servers.
+ """
+
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+ )
+
+ yield self._end_background_update("user_threepids_grandfather")
+
+ defer.returnValue(1)
|