diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 40 |
1 files changed, 13 insertions, 27 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index bd4eb88a92..bda84a744a 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks class RegistrationStore(SQLBaseStore): @@ -101,6 +101,7 @@ class RegistrationStore(SQLBaseStore): make_guest, appservice_id ) + self.get_user_by_id.invalidate((user_id,)) self.is_guest.invalidate((user_id,)) def _register( @@ -156,6 +157,7 @@ class RegistrationStore(SQLBaseStore): (next_id, user_id, token,) ) + @cached() def get_user_by_id(self, user_id): return self._simple_select_one( table="users", @@ -193,6 +195,7 @@ class RegistrationStore(SQLBaseStore): }, { 'password_hash': password_hash }) + self.get_user_by_id.invalidate((user_id,)) @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_ids=[]): @@ -319,26 +322,6 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) - @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, - inlineCallbacks=True) - def are_guests(self, user_ids): - sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( - ",".join("?" for _ in user_ids), - ) - - rows = yield self._execute( - "are_guests", self.cursor_to_dict, sql, *user_ids - ) - - result = {user_id: False for user_id in user_ids} - - result.update({ - row["name"]: bool(row["is_guest"]) - for row in rows - }) - - defer.returnValue(result) - def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id" @@ -458,12 +441,15 @@ class RegistrationStore(SQLBaseStore): """ Gets the 3pid's guest access token if exists, else saves access_token. - :param medium (str): Medium of the 3pid. Must be "email". - :param address (str): 3pid address. - :param access_token (str): The access token to persist if none is - already persisted. - :param inviter_user_id (str): User ID of the inviter. - :return (deferred str): Whichever access token is persisted at the end + Args: + medium (str): Medium of the 3pid. Must be "email". + address (str): 3pid address. + access_token (str): The access token to persist if none is + already persisted. + inviter_user_id (str): User ID of the inviter. + + Returns: + deferred str: Whichever access token is persisted at the end of this function call. """ def insert(txn): |