diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 186 |
1 files changed, 158 insertions, 28 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 70cde0d04d..bd4eb88a92 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + from twisted.internet import defer from synapse.api.errors import StoreError, Codes @@ -38,7 +40,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._access_tokens_id_gen.get_next() + next_id = self._access_tokens_id_gen.get_next() yield self._simple_insert( "access_tokens", @@ -60,7 +62,7 @@ class RegistrationStore(SQLBaseStore): Raises: StoreError if there was a problem adding this. """ - next_id = yield self._refresh_tokens_id_gen.get_next() + next_id = self._refresh_tokens_id_gen.get_next() yield self._simple_insert( "refresh_tokens", @@ -74,7 +76,7 @@ class RegistrationStore(SQLBaseStore): @defer.inlineCallbacks def register(self, user_id, token, password_hash, - was_guest=False, make_guest=False): + was_guest=False, make_guest=False, appservice_id=None): """Attempts to register an account. Args: @@ -85,19 +87,35 @@ class RegistrationStore(SQLBaseStore): upgraded to a non-guest account. make_guest (boolean): True if the the new user should be guest, false to add a regular user account. + appservice_id (str): The ID of the appservice registering the user. Raises: StoreError if the user_id could not be registered. """ yield self.runInteraction( "register", - self._register, user_id, token, password_hash, was_guest, make_guest + self._register, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id ) self.is_guest.invalidate((user_id,)) - def _register(self, txn, user_id, token, password_hash, was_guest, make_guest): + def _register( + self, + txn, + user_id, + token, + password_hash, + was_guest, + make_guest, + appservice_id + ): now = int(self.clock.time()) - next_id = self._access_tokens_id_gen.get_next_txn(txn) + next_id = self._access_tokens_id_gen.get_next() try: if was_guest: @@ -109,9 +127,21 @@ class RegistrationStore(SQLBaseStore): [password_hash, now, 1 if make_guest else 0, user_id]) else: txn.execute("INSERT INTO users " - "(name, password_hash, creation_ts, is_guest) " - "VALUES (?,?,?,?)", - [user_id, password_hash, now, 1 if make_guest else 0]) + "(" + " name," + " password_hash," + " creation_ts," + " is_guest," + " appservice_id" + ") " + "VALUES (?,?,?,?,?)", + [ + user_id, + password_hash, + now, + 1 if make_guest else 0, + appservice_id, + ]) except self.database_engine.module.IntegrityError: raise StoreError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE @@ -134,6 +164,7 @@ class RegistrationStore(SQLBaseStore): }, retcols=["name", "password_hash", "is_guest"], allow_none=True, + desc="get_user_by_id", ) def get_users_by_id_case_insensitive(self, user_id): @@ -164,27 +195,48 @@ class RegistrationStore(SQLBaseStore): }) @defer.inlineCallbacks - def user_delete_access_tokens(self, user_id): - yield self.runInteraction( - "user_delete_access_tokens", - self._user_delete_access_tokens, user_id - ) + def user_delete_access_tokens(self, user_id, except_token_ids=[]): + def f(txn): + sql = "SELECT token FROM access_tokens WHERE user_id = ?" + clauses = [user_id] - def _user_delete_access_tokens(self, txn, user_id): - txn.execute( - "DELETE FROM access_tokens WHERE user_id = ?", - (user_id, ) - ) + if except_token_ids: + sql += " AND id NOT IN (%s)" % ( + ",".join(["?" for _ in except_token_ids]), + ) + clauses += except_token_ids - @defer.inlineCallbacks - def flush_user(self, user_id): - rows = yield self._execute( - 'flush_user', None, - "SELECT token FROM access_tokens WHERE user_id = ?", - user_id - ) - for r in rows: - self.get_user_by_access_token.invalidate((r,)) + txn.execute(sql, clauses) + + rows = txn.fetchall() + + n = 100 + chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] + for chunk in chunks: + for row in chunk: + txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + + txn.execute( + "DELETE FROM access_tokens WHERE token in (%s)" % ( + ",".join(["?" for _ in chunk]), + ), [r[0] for r in chunk] + ) + + yield self.runInteraction("user_delete_access_tokens", f) + + def delete_access_token(self, access_token): + def f(txn): + self._simple_delete_one_txn( + txn, + table="access_tokens", + keyvalues={ + "token": access_token + }, + ) + + txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) + + return self.runInteraction("delete_access_token", f) @cached() def get_user_by_access_token(self, token): @@ -350,3 +402,81 @@ class RegistrationStore(SQLBaseStore): ret = yield self.runInteraction("count_users", _count_users) defer.returnValue(ret) + + @defer.inlineCallbacks + def find_next_generated_user_id_localpart(self): + """ + Gets the localpart of the next generated user ID. + + Generated user IDs are integers, and we aim for them to be as small as + we can. Unfortunately, it's possible some of them are already taken by + existing users, and there may be gaps in the already taken range. This + function returns the start of the first allocatable gap. This is to + avoid the case of ID 10000000 being pre-allocated, so us wasting the + first (and shortest) many generated user IDs. + """ + def _find_next_generated_user_id(txn): + txn.execute("SELECT name FROM users") + rows = self.cursor_to_dict(txn) + + regex = re.compile("^@(\d+):") + + found = set() + + for r in rows: + user_id = r["name"] + match = regex.search(user_id) + if match: + found.add(int(match.group(1))) + for i in xrange(len(found) + 1): + if i not in found: + return i + + defer.returnValue((yield self.runInteraction( + "find_next_generated_user_id", + _find_next_generated_user_id + ))) + + @defer.inlineCallbacks + def get_3pid_guest_access_token(self, medium, address): + ret = yield self._simple_select_one( + "threepid_guest_access_tokens", + { + "medium": medium, + "address": address + }, + ["guest_access_token"], True, 'get_3pid_guest_access_token' + ) + if ret: + defer.returnValue(ret["guest_access_token"]) + defer.returnValue(None) + + @defer.inlineCallbacks + def save_or_get_3pid_guest_access_token( + self, medium, address, access_token, inviter_user_id + ): + """ + Gets the 3pid's guest access token if exists, else saves access_token. + + :param medium (str): Medium of the 3pid. Must be "email". + :param address (str): 3pid address. + :param access_token (str): The access token to persist if none is + already persisted. + :param inviter_user_id (str): User ID of the inviter. + :return (deferred str): Whichever access token is persisted at the end + of this function call. + """ + def insert(txn): + txn.execute( + "INSERT INTO threepid_guest_access_tokens " + "(medium, address, guest_access_token, first_inviter) " + "VALUES (?, ?, ?, ?)", + (medium, address, access_token, inviter_user_id) + ) + + try: + yield self.runInteraction("save_3pid_guest_access_token", insert) + defer.returnValue(access_token) + except self.database_engine.module.IntegrityError: + ret = yield self.get_3pid_guest_access_token(medium, address) + defer.returnValue(ret) |