summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py40
1 files changed, 29 insertions, 11 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index f0fa0bd33c..c79066f774 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import StoreError, Codes
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 
 class RegistrationStore(SQLBaseStore):
@@ -73,7 +73,8 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash, was_guest=False):
+    def register(self, user_id, token, password_hash,
+                 was_guest=False, make_guest=False):
         """Attempts to register an account.
 
         Args:
@@ -82,15 +83,18 @@ class RegistrationStore(SQLBaseStore):
             password_hash (str): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
+            make_guest (boolean): True if the the new user should be guest,
+                false to add a regular user account.
         Raises:
             StoreError if the user_id could not be registered.
         """
         yield self.runInteraction(
             "register",
-            self._register, user_id, token, password_hash, was_guest
+            self._register, user_id, token, password_hash, was_guest, make_guest
         )
+        self.is_guest.invalidate((user_id,))
 
-    def _register(self, txn, user_id, token, password_hash, was_guest):
+    def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
         now = int(self.clock.time())
 
         next_id = self._access_tokens_id_gen.get_next_txn(txn)
@@ -100,12 +104,14 @@ class RegistrationStore(SQLBaseStore):
                 txn.execute("UPDATE users SET"
                             " password_hash = ?,"
                             " upgrade_ts = ?"
+                            " is_guest = ?"
                             " WHERE name = ?",
-                            [password_hash, now, user_id])
+                            [password_hash, now, make_guest, user_id])
             else:
-                txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
-                            "VALUES (?,?,?)",
-                            [user_id, password_hash, now])
+                txn.execute("INSERT INTO users "
+                            "(name, password_hash, creation_ts, is_guest) "
+                            "VALUES (?,?,?,?)",
+                            [user_id, password_hash, now, make_guest])
         except self.database_engine.module.IntegrityError:
             raise StoreError(
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -126,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
             keyvalues={
                 "name": user_id,
             },
-            retcols=["name", "password_hash"],
+            retcols=["name", "password_hash", "is_guest"],
             allow_none=True,
         )
 
@@ -136,7 +142,7 @@ class RegistrationStore(SQLBaseStore):
         """
         def f(txn):
             sql = (
-                "SELECT name, password_hash FROM users"
+                "SELECT name, password_hash, is_guest FROM users"
                 " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
@@ -249,9 +255,21 @@ class RegistrationStore(SQLBaseStore):
 
         defer.returnValue(res if res else False)
 
+    @cachedInlineCallbacks()
+    def is_guest(self, user):
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        defer.returnValue(res if res else False)
+
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, access_tokens.id as token_id"
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"