diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index f0fa0bd33c..ece71f2ee8 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore):
@@ -73,7 +73,8 @@ class RegistrationStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def register(self, user_id, token, password_hash, was_guest=False):
+ def register(self, user_id, token, password_hash,
+ was_guest=False, make_guest=False):
"""Attempts to register an account.
Args:
@@ -82,15 +83,18 @@ class RegistrationStore(SQLBaseStore):
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
+ make_guest (boolean): True if the the new user should be guest,
+ false to add a regular user account.
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
"register",
- self._register, user_id, token, password_hash, was_guest
+ self._register, user_id, token, password_hash, was_guest, make_guest
)
+ self.is_guest.invalidate((user_id,))
- def _register(self, txn, user_id, token, password_hash, was_guest):
+ def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn)
@@ -99,13 +103,15 @@ class RegistrationStore(SQLBaseStore):
if was_guest:
txn.execute("UPDATE users SET"
" password_hash = ?,"
- " upgrade_ts = ?"
+ " upgrade_ts = ?,"
+ " is_guest = ?"
" WHERE name = ?",
- [password_hash, now, user_id])
+ [password_hash, now, make_guest, user_id])
else:
- txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
- "VALUES (?,?,?)",
- [user_id, password_hash, now])
+ txn.execute("INSERT INTO users "
+ "(name, password_hash, creation_ts, is_guest) "
+ "VALUES (?,?,?,?)",
+ [user_id, password_hash, now, make_guest])
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -126,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={
"name": user_id,
},
- retcols=["name", "password_hash"],
+ retcols=["name", "password_hash", "is_guest"],
allow_none=True,
)
@@ -249,9 +255,21 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
+ @cachedInlineCallbacks()
+ def is_guest(self, user):
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user.to_string()},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ defer.returnValue(res if res else False)
+
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
|