summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/registration.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 999b710fbb..70cde0d04d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import StoreError, Codes
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 
 class RegistrationStore(SQLBaseStore):
@@ -256,10 +256,10 @@ class RegistrationStore(SQLBaseStore):
         defer.returnValue(res if res else False)
 
     @cachedInlineCallbacks()
-    def is_guest(self, user):
+    def is_guest(self, user_id):
         res = yield self._simple_select_one_onecol(
             table="users",
-            keyvalues={"name": user.to_string()},
+            keyvalues={"name": user_id},
             retcol="is_guest",
             allow_none=True,
             desc="is_guest",
@@ -267,6 +267,26 @@ class RegistrationStore(SQLBaseStore):
 
         defer.returnValue(res if res else False)
 
+    @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+                inlineCallbacks=True)
+    def are_guests(self, user_ids):
+        sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
+            ",".join("?" for _ in user_ids),
+        )
+
+        rows = yield self._execute(
+            "are_guests", self.cursor_to_dict, sql, *user_ids
+        )
+
+        result = {user_id: False for user_id in user_ids}
+
+        result.update({
+            row["name"]: bool(row["is_guest"])
+            for row in rows
+        })
+
+        defer.returnValue(result)
+
     def _query_for_auth(self, txn, token):
         sql = (
             "SELECT users.name, users.is_guest, access_tokens.id as token_id"