diff options
author | Erik Johnston <erik@matrix.org> | 2019-04-17 19:44:40 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2019-04-17 19:44:40 +0100 |
commit | ca90336a6935b36b5761244005b0f68b496d5d79 (patch) | |
tree | 6bbce5eafc0db3b24ccc3b59b051da850382ae09 /synapse/storage/registration.py | |
parent | Add management endpoints for account validity (diff) | |
parent | Merge pull request #5047 from matrix-org/babolivier/account_expiration (diff) | |
download | synapse-ca90336a6935b36b5761244005b0f68b496d5d79.tar.xz |
Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/account_expiration
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 331 |
1 files changed, 206 insertions, 125 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index dfdb4e7e34..03a06a83d6 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -38,13 +38,15 @@ class RegistrationWorkerStore(SQLBaseStore): def get_user_by_id(self, user_id): return self._simple_select_one( table="users", - keyvalues={ - "name": user_id, - }, + keyvalues={"name": user_id}, retcols=[ - "name", "password_hash", "is_guest", - "consent_version", "consent_server_notice_sent", - "appservice_id", "creation_ts", + "name", + "password_hash", + "is_guest", + "consent_version", + "consent_server_notice_sent", + "appservice_id", + "creation_ts", ], allow_none=True, desc="get_user_by_id", @@ -82,9 +84,7 @@ class RegistrationWorkerStore(SQLBaseStore): including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token + "get_user_by_access_token", self._query_for_auth, token ) @cachedInlineCallbacks() @@ -300,10 +300,10 @@ class RegistrationWorkerStore(SQLBaseStore): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. """ + def f(txn): sql = ( - "SELECT name, password_hash FROM users" - " WHERE lower(name) = lower(?)" + "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)" ) txn.execute(sql, (user_id,)) return dict(txn) @@ -313,6 +313,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" + def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") rows = self.cursor_to_dict(txn) @@ -330,6 +331,7 @@ class RegistrationWorkerStore(SQLBaseStore): 3) bridged users who registered on the homeserver in the past 24 hours """ + def _count_daily_user_type(txn): yesterday = int(self._clock.time()) - (60 * 60 * 24) @@ -350,15 +352,18 @@ class RegistrationWorkerStore(SQLBaseStore): for row in txn: results[row[0]] = row[1] return results + return self.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): def _count_users(txn): - txn.execute(""" + txn.execute( + """ SELECT COALESCE(COUNT(*), 0) FROM users WHERE appservice_id IS NULL - """) + """ + ) count, = txn.fetchone() return count @@ -377,6 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore): avoid the case of ID 10000000 being pre-allocated, so us wasting the first (and shortest) many generated user IDs. """ + def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") @@ -384,7 +390,7 @@ class RegistrationWorkerStore(SQLBaseStore): found = set() - for user_id, in txn: + for (user_id,) in txn: match = regex.search(user_id) if match: found.add(int(match.group(1))) @@ -392,20 +398,22 @@ class RegistrationWorkerStore(SQLBaseStore): if i not in found: return i - defer.returnValue((yield self.runInteraction( - "find_next_generated_user_id", - _find_next_generated_user_id - ))) + defer.returnValue( + ( + yield self.runInteraction( + "find_next_generated_user_id", _find_next_generated_user_id + ) + ) + ) @defer.inlineCallbacks def get_3pid_guest_access_token(self, medium, address): ret = yield self._simple_select_one( "threepid_guest_access_tokens", - { - "medium": medium, - "address": address - }, - ["guest_access_token"], True, 'get_3pid_guest_access_token' + {"medium": medium, "address": address}, + ["guest_access_token"], + True, + 'get_3pid_guest_access_token', ) if ret: defer.returnValue(ret["guest_access_token"]) @@ -423,8 +431,7 @@ class RegistrationWorkerStore(SQLBaseStore): Deferred[str|None]: user id or None if no user id/threepid mapping exists """ user_id = yield self.runInteraction( - "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, - medium, address + "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) defer.returnValue(user_id) @@ -442,11 +449,9 @@ class RegistrationWorkerStore(SQLBaseStore): ret = self._simple_select_one_txn( txn, "user_threepids", - { - "medium": medium, - "address": address - }, - ['user_id'], True + {"medium": medium, "address": address}, + ['user_id'], + True, ) if ret: return ret['user_id'] @@ -454,41 +459,110 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert("user_threepids", { - "medium": medium, - "address": address, - }, { - "user_id": user_id, - "validated_at": validated_at, - "added_at": added_at, - }) + yield self._simple_upsert( + "user_threepids", + {"medium": medium, "address": address}, + {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, + ) @defer.inlineCallbacks def user_get_threepids(self, user_id): ret = yield self._simple_select_list( - "user_threepids", { - "user_id": user_id - }, + "user_threepids", + {"user_id": user_id}, ['medium', 'address', 'validated_at', 'added_at'], - 'user_get_threepids' + 'user_get_threepids', ) defer.returnValue(ret) def user_delete_threepid(self, user_id, medium, address): return self._simple_delete( "user_threepids", + keyvalues={"user_id": user_id, "medium": medium, "address": address}, + desc="user_delete_threepids", + ) + + def add_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied a bind request to the given identity server on + behalf of the given user. We need to remember this in case the user + asks us to unbind the threepid. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + # We need to use an upsert, in case they user had already bound the + # threepid + return self._simple_upsert( + table="user_threepid_id_server", keyvalues={ "user_id": user_id, "medium": medium, "address": address, + "id_server": id_server, }, - desc="user_delete_threepids", + values={}, + insertion_values={}, + desc="add_user_bound_threepid", ) + def remove_user_bound_threepid(self, user_id, medium, address, id_server): + """The server proxied an unbind request to the given identity server on + behalf of the given user, so we remove the mapping of threepid to + identity server. + + Args: + user_id (str) + medium (str) + address (str) + id_server (str) + + Returns: + Deferred + """ + return self._simple_delete( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + "id_server": id_server, + }, + desc="remove_user_bound_threepid", + ) -class RegistrationStore(RegistrationWorkerStore, - background_updates.BackgroundUpdateStore): + def get_id_servers_user_bound(self, user_id, medium, address): + """Get the list of identity servers that the server proxied bind + requests to for given user and threepid + Args: + user_id (str) + medium (str) + address (str) + + Returns: + Deferred[list[str]]: Resolves to a list of identity servers + """ + return self._simple_select_onecol( + table="user_threepid_id_server", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + }, + retcol="id_server", + desc="get_id_servers_user_bound", + ) + + +class RegistrationStore( + RegistrationWorkerStore, background_updates.BackgroundUpdateStore +): def __init__(self, db_conn, hs): super(RegistrationStore, self).__init__(db_conn, hs) @@ -515,6 +589,10 @@ class RegistrationStore(RegistrationWorkerStore, # clear the background update. self.register_noop_background_update("refresh_tokens_device_index") + self.register_background_update_handler( + "user_threepids_grandfather", self._bg_user_threepids_grandfather, + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -531,18 +609,22 @@ class RegistrationStore(RegistrationWorkerStore, yield self._simple_insert( "access_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token, - "device_id": device_id, - }, + {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id}, desc="add_access_token_to_user", ) - def register(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, user_type=None): + def register( + self, + user_id, + token=None, + password_hash=None, + was_guest=False, + make_guest=False, + appservice_id=None, + create_profile_with_displayname=None, + admin=False, + user_type=None, + ): """Attempts to register an account. Args: @@ -576,7 +658,7 @@ class RegistrationStore(RegistrationWorkerStore, appservice_id, create_profile_with_displayname, admin, - user_type + user_type, ) def _register( @@ -606,10 +688,7 @@ class RegistrationStore(RegistrationWorkerStore, self._simple_select_one_txn( txn, "users", - keyvalues={ - "name": user_id, - "is_guest": 1, - }, + keyvalues={"name": user_id, "is_guest": 1}, retcols=("name",), allow_none=False, ) @@ -617,10 +696,7 @@ class RegistrationStore(RegistrationWorkerStore, self._simple_update_one_txn( txn, "users", - keyvalues={ - "name": user_id, - "is_guest": 1, - }, + keyvalues={"name": user_id, "is_guest": 1}, updatevalues={ "password_hash": password_hash, "upgrade_ts": now, @@ -628,7 +704,7 @@ class RegistrationStore(RegistrationWorkerStore, "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, - } + }, ) else: self._simple_insert_txn( @@ -642,13 +718,11 @@ class RegistrationStore(RegistrationWorkerStore, "appservice_id": appservice_id, "admin": 1 if admin else 0, "user_type": user_type, - } + }, ) except self.database_engine.module.IntegrityError: - raise StoreError( - 400, "User ID already taken.", errcode=Codes.USER_IN_USE - ) + raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) if self._account_validity.enabled: now_ms = self.clock.time_msec() @@ -667,9 +741,8 @@ class RegistrationStore(RegistrationWorkerStore, # it's possible for this to get a conflict, but only for a single user # since tokens are namespaced based on their user ID txn.execute( - "INSERT INTO access_tokens(id, user_id, token)" - " VALUES (?,?,?)", - (next_id, user_id, token,) + "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)", + (next_id, user_id, token), ) if create_profile_with_displayname: @@ -680,12 +753,10 @@ class RegistrationStore(RegistrationWorkerStore, # while everything else uses the full mxid. txn.execute( "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", - (user_id_obj.localpart, create_profile_with_displayname) + (user_id_obj.localpart, create_profile_with_displayname), ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) - ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) def user_set_password_hash(self, user_id, password_hash): @@ -694,22 +765,14 @@ class RegistrationStore(RegistrationWorkerStore, removes most of the entries subsequently anyway so it would be pointless. Use flush_user separately. """ + def user_set_password_hash_txn(txn): self._simple_update_one_txn( - txn, - 'users', { - 'name': user_id - }, - { - 'password_hash': password_hash - } - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) + txn, 'users', {'name': user_id}, {'password_hash': password_hash} ) - return self.runInteraction( - "user_set_password_hash", user_set_password_hash_txn - ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + + return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -722,16 +785,16 @@ class RegistrationStore(RegistrationWorkerStore, Raises: StoreError(404) if user not found """ + def f(txn): self._simple_update_one_txn( txn, table='users', - keyvalues={'name': user_id, }, - updatevalues={'consent_version': consent_version, }, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) + keyvalues={'name': user_id}, + updatevalues={'consent_version': consent_version}, ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + return self.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): @@ -746,20 +809,19 @@ class RegistrationStore(RegistrationWorkerStore, Raises: StoreError(404) if user not found """ + def f(txn): self._simple_update_one_txn( txn, table='users', - keyvalues={'name': user_id, }, - updatevalues={'consent_server_notice_sent': consent_version, }, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_by_id, (user_id,) + keyvalues={'name': user_id}, + updatevalues={'consent_server_notice_sent': consent_version}, ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) + return self.runInteraction("user_set_consent_server_notice_sent", f) - def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None): + def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ Invalidate access tokens belonging to a user @@ -774,10 +836,9 @@ class RegistrationStore(RegistrationWorkerStore, defer.Deferred[list[str, int, str|None, int]]: a list of (token, token id, device id) for each of the deleted tokens """ + def f(txn): - keyvalues = { - "user_id": user_id, - } + keyvalues = {"user_id": user_id} if device_id is not None: keyvalues["device_id"] = device_id @@ -789,8 +850,9 @@ class RegistrationStore(RegistrationWorkerStore, values.append(except_token_id) txn.execute( - "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, - values + "SELECT token, id, device_id FROM access_tokens WHERE %s" + % where_clause, + values, ) tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] @@ -799,25 +861,16 @@ class RegistrationStore(RegistrationWorkerStore, txn, self.get_user_by_access_token, (token,) ) - txn.execute( - "DELETE FROM access_tokens WHERE %s" % where_clause, - values - ) + txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) return tokens_and_devices - return self.runInteraction( - "user_delete_access_tokens", f, - ) + return self.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): self._simple_delete_one_txn( - txn, - table="access_tokens", - keyvalues={ - "token": access_token - }, + txn, table="access_tokens", keyvalues={"token": access_token} ) self._invalidate_cache_and_stream( @@ -840,7 +893,7 @@ class RegistrationStore(RegistrationWorkerStore, @defer.inlineCallbacks def save_or_get_3pid_guest_access_token( - self, medium, address, access_token, inviter_user_id + self, medium, address, access_token, inviter_user_id ): """ Gets the 3pid's guest access token if exists, else saves access_token. @@ -856,12 +909,13 @@ class RegistrationStore(RegistrationWorkerStore, deferred str: Whichever access token is persisted at the end of this function call. """ + def insert(txn): txn.execute( "INSERT INTO threepid_guest_access_tokens " "(medium, address, guest_access_token, first_inviter) " "VALUES (?, ?, ?, ?)", - (medium, address, access_token, inviter_user_id) + (medium, address, access_token, inviter_user_id), ) try: @@ -878,9 +932,7 @@ class RegistrationStore(RegistrationWorkerStore, """ return self._simple_insert( "users_pending_deactivation", - values={ - "user_id": user_id, - }, + values={"user_id": user_id}, desc="add_user_pending_deactivation", ) @@ -893,9 +945,7 @@ class RegistrationStore(RegistrationWorkerStore, # the table, so somehow duplicate entries have ended up in it. return self._simple_delete( "users_pending_deactivation", - keyvalues={ - "user_id": user_id, - }, + keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", ) @@ -911,3 +961,34 @@ class RegistrationStore(RegistrationWorkerStore, allow_none=True, desc="get_users_pending_deactivation", ) + + @defer.inlineCallbacks + def _bg_user_threepids_grandfather(self, progress, batch_size): + """We now track which identity servers a user binds their 3PID to, so + we need to handle the case of existing bindings where we didn't track + this. + + We do this by grandfathering in existing user threepids assuming that + they used one of the server configured trusted identity servers. + """ + + id_servers = set(self.config.trusted_third_party_id_servers) + + def _bg_user_threepids_grandfather_txn(txn): + sql = """ + INSERT INTO user_threepid_id_server + (user_id, medium, address, id_server) + SELECT user_id, medium, address, ? + FROM user_threepids + """ + + txn.executemany(sql, [(id_server,) for id_server in id_servers]) + + if id_servers: + yield self.runInteraction( + "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn, + ) + + yield self._end_background_update("user_threepids_grandfather") + + defer.returnValue(1) |