diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 338 |
1 files changed, 182 insertions, 156 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a92..983a8ec52b 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,25 +18,40 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes +from synapse.storage import background_updates +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList - -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,51 +62,34 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) - @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): - """Adds a refresh token for the given user. - - Args: - user_id (str): The user ID. - token (str): The new refresh token to add. - Raises: - StoreError if there was a problem adding this. - """ - next_id = self._refresh_tokens_id_gen.get_next() - - yield self._simple_insert( - "refresh_tokens", - { - "id": next_id, - "user_id": user_id, - "token": token - }, - desc="add_refresh_token_to_user", - ) - - @defer.inlineCallbacks - def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False, appservice_id=None): + def register(self, user_id, token=None, password_hash=None, + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. appservice_id (str): The ID of the appservice registering the user. + create_profile_with_localpart (str): Optionally create a profile for + the given localpart. Raises: StoreError if the user_id could not be registered. """ - yield self.runInteraction( + return self.runInteraction( "register", self._register, user_id, @@ -99,9 +97,10 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, + admin ) - self.is_guest.invalidate((user_id,)) def _register( self, @@ -111,7 +110,9 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, + admin, ): now = int(self.clock.time()) @@ -119,29 +120,48 @@ class RegistrationStore(SQLBaseStore): try: if was_guest: - txn.execute("UPDATE users SET" - " password_hash = ?," - " upgrade_ts = ?," - " is_guest = ?" - " WHERE name = ?", - [password_hash, now, 1 if make_guest else 0, user_id]) + # Ensure that the guest user actually exists + # ``allow_none=False`` makes this raise an exception + # if the row isn't in the database. + self._simple_select_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + "is_guest": 1, + }, + retcols=("name",), + allow_none=False, + ) + + self._simple_update_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + "is_guest": 1, + }, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + } + ) else: - txn.execute("INSERT INTO users " - "(" - " name," - " password_hash," - " creation_ts," - " is_guest," - " appservice_id" - ") " - "VALUES (?,?,?,?,?)", - [ - user_id, - password_hash, - now, - 1 if make_guest else 0, - appservice_id, - ]) + self._simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + } + ) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE @@ -156,6 +176,18 @@ class RegistrationStore(SQLBaseStore): (next_id, user_id, token,) ) + if create_profile_with_localpart: + txn.execute( + "INSERT INTO profiles(user_id) VALUES (?)", + (create_profile_with_localpart,) + ) + + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( table="users", @@ -181,48 +213,88 @@ class RegistrationStore(SQLBaseStore): return self.runInteraction("get_users_by_id_case_insensitive", f) - @defer.inlineCallbacks def user_set_password_hash(self, user_id, password_hash): """ NB. This does *not* evict any cache because the one use for this removes most of the entries subsequently anyway so it would be pointless. Use flush_user separately. """ - yield self._simple_update_one('users', { - 'name': user_id - }, { - 'password_hash': password_hash - }) + def user_set_password_hash_txn(txn): + self._simple_update_one_txn( + txn, + 'users', { + 'name': user_id + }, + { + 'password_hash': password_hash + } + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user_id,) + ) + return self.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" - clauses = [user_id] + def user_delete_access_tokens(self, user_id, except_token_id=None, + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user - if except_token_ids: - sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str): list of access_tokens IDs which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn): + keyvalues = { + "user_id": user_id, + } + if device_id is not None: + keyvalues["device_id"] = device_id + + if delete_refresh_tokens: + self._simple_delete_txn( + txn, + table="refresh_tokens", + keyvalues=keyvalues, ) - clauses += except_token_ids - txn.execute(sql, clauses) + items = keyvalues.items() + where_clause = " AND ".join(k + " = ?" for k, _ in items) + values = [v for _, v in items] + if except_token_id: + where_clause += " AND id != ?" + values.append(except_token_id) - rows = txn.fetchall() - - n = 100 - chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] - for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + txn.execute( + "SELECT token FROM access_tokens WHERE %s" % where_clause, + values + ) + rows = self.cursor_to_dict(txn) - txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( - ",".join(["?" for _ in chunk]), - ), [r[0] for r in chunk] + for row in rows: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (row["token"],) ) - yield self.runInteraction("user_delete_access_tokens", f) + txn.execute( + "DELETE FROM access_tokens WHERE %s" % where_clause, + values + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + ) def delete_access_token(self, access_token): def f(txn): @@ -234,7 +306,9 @@ class RegistrationStore(SQLBaseStore): }, ) - txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (access_token,) + ) return self.runInteraction("delete_access_token", f) @@ -245,9 +319,8 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", @@ -255,46 +328,6 @@ class RegistrationStore(SQLBaseStore): token ) - def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. - - Doing so invalidates the old refresh token - refresh tokens are single - use. - - Args: - token (str): The refresh token of a user. - token_generator (fn: str -> str): Function which, when given a - user ID, returns a unique refresh token for that user. This - function must never return the same value twice. - Returns: - tuple of (user_id, refresh_token) - Raises: - StoreError if no user was found with that refresh token. - """ - return self.runInteraction( - "exchange_refresh_token", - self._exchange_refresh_token, - refresh_token, - token_generator - ) - - def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" - txn.execute(sql, (old_token,)) - rows = self.cursor_to_dict(txn) - if not rows: - raise StoreError(403, "Did not recognize refresh token") - user_id = rows[0]["user_id"] - - # TODO(danielwh): Maybe perform a validation on the macaroon that - # macaroon.user_id == user_id. - - new_token = token_generator(user_id) - sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" - txn.execute(sql, (new_token, old_token,)) - - return user_id, new_token - @defer.inlineCallbacks def is_server_admin(self, user): res = yield self._simple_select_one_onecol( @@ -319,29 +352,10 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, - inlineCallbacks=True) - def are_guests(self, user_ids): - sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( - ",".join("?" for _ in user_ids), - ) - - rows = yield self._execute( - "are_guests", self.cursor_to_dict, sql, *user_ids - ) - - result = {user_id: False for user_id in user_ids} - - result.update({ - row["name"]: bool(row["is_guest"]) - for row in rows - }) - - defer.returnValue(result) - def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" @@ -390,6 +404,15 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(ret['user_id']) defer.returnValue(None) + def user_delete_threepids(self, user_id): + return self._simple_delete( + "user_threepids", + keyvalues={ + "user_id": user_id, + }, + desc="user_delete_threepids", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" @@ -458,12 +481,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): |