diff options
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r-- | synapse/storage/registration.py | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 999b710fbb..70cde0d04d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList class RegistrationStore(SQLBaseStore): @@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) @cachedInlineCallbacks() - def is_guest(self, user): + def is_guest(self, user_id): res = yield self._simple_select_one_onecol( table="users", - keyvalues={"name": user.to_string()}, + keyvalues={"name": user_id}, retcol="is_guest", allow_none=True, desc="is_guest", @@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore): defer.returnValue(res if res else False) + @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1, + inlineCallbacks=True) + def are_guests(self, user_ids): + sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % ( + ",".join("?" for _ in user_ids), + ) + + rows = yield self._execute( + "are_guests", self.cursor_to_dict, sql, *user_ids + ) + + result = {user_id: False for user_id in user_ids} + + result.update({ + row["name"]: bool(row["is_guest"]) + for row in rows + }) + + defer.returnValue(result) + def _query_for_auth(self, txn, token): sql = ( "SELECT users.name, users.is_guest, access_tokens.id as token_id" |