diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 207 |
1 files changed, 154 insertions, 53 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bda84a744a..7e7d32eb66 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,25 +18,40 @@ import re from twisted.internet import defer from synapse.api.errors import StoreError, Codes - -from ._base import SQLBaseStore +from synapse.storage import background_updates from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -class RegistrationStore(SQLBaseStore): +class RegistrationStore(background_updates.BackgroundUpdateStore): def __init__(self, hs): super(RegistrationStore, self).__init__(hs) self.clock = hs.get_clock() + self.register_background_index_update( + "access_tokens_device_index", + index_name="access_tokens_device_id", + table="access_tokens", + columns=["user_id", "device_id"], + ) + + self.register_background_index_update( + "refresh_tokens_device_index", + index_name="refresh_tokens_device_id", + table="refresh_tokens", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks - def add_access_token_to_user(self, user_id, token): + def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. Args: user_id (str): The user ID. token (str): The new access token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_access_token_to_user", ) @defer.inlineCallbacks - def add_refresh_token_to_user(self, user_id, token): + def add_refresh_token_to_user(self, user_id, token, device_id=None): """Adds a refresh token for the given user. Args: user_id (str): The user ID. token (str): The new refresh token to add. + device_id (str): ID of the device to associate with the access + token Raises: StoreError if there was a problem adding this. """ @@ -69,25 +87,31 @@ class RegistrationStore(SQLBaseStore): { "id": next_id, "user_id": user_id, - "token": token + "token": token, + "device_id": device_id, }, desc="add_refresh_token_to_user", ) @defer.inlineCallbacks - def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False, appservice_id=None): + def register(self, user_id, token=None, password_hash=None, + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_localpart=None, admin=False): """Attempts to register an account. Args: user_id (str): The desired user ID to register. - token (str): The desired access token to use for this user. + token (str): The desired access token to use for this user. If this + is not None, the given access token is associated with the user + id. password_hash (str): Optional. The password hash for this user. was_guest (bool): Optional. Whether this is a guest account being upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. appservice_id (str): The ID of the appservice registering the user. + create_profile_with_localpart (str): Optionally create a profile for + the given localpart. Raises: StoreError if the user_id could not be registered. """ @@ -99,7 +123,9 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, + admin ) self.get_user_by_id.invalidate((user_id,)) self.is_guest.invalidate((user_id,)) @@ -112,7 +138,9 @@ class RegistrationStore(SQLBaseStore): password_hash, was_guest, make_guest, - appservice_id + appservice_id, + create_profile_with_localpart, + admin, ): now = int(self.clock.time()) @@ -120,29 +148,48 @@ class RegistrationStore(SQLBaseStore): try: if was_guest: - txn.execute("UPDATE users SET" - " password_hash = ?," - " upgrade_ts = ?," - " is_guest = ?" - " WHERE name = ?", - [password_hash, now, 1 if make_guest else 0, user_id]) + # Ensure that the guest user actually exists + # ``allow_none=False`` makes this raise an exception + # if the row isn't in the database. + self._simple_select_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + "is_guest": 1, + }, + retcols=("name",), + allow_none=False, + ) + + self._simple_update_one_txn( + txn, + "users", + keyvalues={ + "name": user_id, + "is_guest": 1, + }, + updatevalues={ + "password_hash": password_hash, + "upgrade_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + } + ) else: - txn.execute("INSERT INTO users " - "(" - " name," - " password_hash," - " creation_ts," - " is_guest," - " appservice_id" - ") " - "VALUES (?,?,?,?,?)", - [ - user_id, - password_hash, - now, - 1 if make_guest else 0, - appservice_id, - ]) + self._simple_insert_txn( + txn, + "users", + values={ + "name": user_id, + "password_hash": password_hash, + "creation_ts": now, + "is_guest": 1 if make_guest else 0, + "appservice_id": appservice_id, + "admin": 1 if admin else 0, + } + ) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE @@ -157,6 +204,12 @@ class RegistrationStore(SQLBaseStore): (next_id, user_id, token,) ) + if create_profile_with_localpart: + txn.execute( + "INSERT INTO profiles(user_id) VALUES (?)", + (create_profile_with_localpart,) + ) + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( @@ -198,16 +251,37 @@ class RegistrationStore(SQLBaseStore): self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id, except_token_ids=[]): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" + def user_delete_access_tokens(self, user_id, except_token_ids=[], + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_ids (list[str]): list of access_tokens which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn, table, except_tokens, call_after_delete): + sql = "SELECT token FROM %s WHERE user_id = ?" % table clauses = [user_id] - if except_token_ids: + if device_id is not None: + sql += " AND device_id = ?" + clauses.append(device_id) + + if except_tokens: sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + ",".join(["?" for _ in except_tokens]), ) - clauses += except_token_ids + clauses += except_tokens txn.execute(sql, clauses) @@ -216,16 +290,33 @@ class RegistrationStore(SQLBaseStore): n = 100 chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + if call_after_delete: + for row in chunk: + txn.call_after(call_after_delete, (row[0],)) txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( + "DELETE FROM %s WHERE token in (%s)" % ( + table, ",".join(["?" for _ in chunk]), ), [r[0] for r in chunk] ) - yield self.runInteraction("user_delete_access_tokens", f) + # delete refresh tokens first, to stop new access tokens being + # allocated while our backs are turned + if delete_refresh_tokens: + yield self.runInteraction( + "user_delete_access_tokens", f, + table="refresh_tokens", + except_tokens=[], + call_after_delete=None, + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + table="access_tokens", + except_tokens=except_token_ids, + call_after_delete=self.get_user_by_access_token.invalidate, + ) def delete_access_token(self, access_token): def f(txn): @@ -248,9 +339,8 @@ class RegistrationStore(SQLBaseStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", @@ -259,18 +349,18 @@ class RegistrationStore(SQLBaseStore): ) def exchange_refresh_token(self, refresh_token, token_generator): - """Exchange a refresh token for a new access token and refresh token. + """Exchange a refresh token for a new one. Doing so invalidates the old refresh token - refresh tokens are single use. Args: - token (str): The refresh token of a user. + refresh_token (str): The refresh token of a user. token_generator (fn: str -> str): Function which, when given a user ID, returns a unique refresh token for that user. This function must never return the same value twice. Returns: - tuple of (user_id, refresh_token) + tuple of (user_id, new_refresh_token, device_id) Raises: StoreError if no user was found with that refresh token. """ @@ -282,12 +372,13 @@ class RegistrationStore(SQLBaseStore): ) def _exchange_refresh_token(self, txn, old_token, token_generator): - sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" + sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" txn.execute(sql, (old_token,)) rows = self.cursor_to_dict(txn) if not rows: raise StoreError(403, "Did not recognize refresh token") user_id = rows[0]["user_id"] + device_id = rows[0]["device_id"] # TODO(danielwh): Maybe perform a validation on the macaroon that # macaroon.user_id == user_id. @@ -296,7 +387,7 @@ class RegistrationStore(SQLBaseStore): sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" txn.execute(sql, (new_token, old_token,)) - return user_id, new_token + return user_id, new_token, device_id @defer.inlineCallbacks def is_server_admin(self, user): @@ -324,7 +415,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id" + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" " FROM users" " INNER JOIN access_tokens on users.name = access_tokens.user_id" " WHERE token = ?" @@ -373,6 +465,15 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(ret['user_id']) defer.returnValue(None) + def user_delete_threepids(self, user_id): + return self._simple_delete( + "user_threepids", + keyvalues={ + "user_id": user_id, + }, + desc="user_delete_threepids", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" |