diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 09a05b08ef..967c732bda 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
+
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
class RegistrationStore(SQLBaseStore):
@@ -73,30 +75,45 @@ class RegistrationStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def register(self, user_id, token, password_hash):
+ def register(self, user_id, token, password_hash,
+ was_guest=False, make_guest=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user.
password_hash (str): Optional. The password hash for this user.
+ was_guest (bool): Optional. Whether this is a guest account being
+ upgraded to a non-guest account.
+ make_guest (boolean): True if the the new user should be guest,
+ false to add a regular user account.
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(
"register",
- self._register, user_id, token, password_hash
+ self._register, user_id, token, password_hash, was_guest, make_guest
)
+ self.is_guest.invalidate((user_id,))
- def _register(self, txn, user_id, token, password_hash):
+ def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn)
try:
- txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
- "VALUES (?,?,?)",
- [user_id, password_hash, now])
+ if was_guest:
+ txn.execute("UPDATE users SET"
+ " password_hash = ?,"
+ " upgrade_ts = ?,"
+ " is_guest = ?"
+ " WHERE name = ?",
+ [password_hash, now, 1 if make_guest else 0, user_id])
+ else:
+ txn.execute("INSERT INTO users "
+ "(name, password_hash, creation_ts, is_guest) "
+ "VALUES (?,?,?,?)",
+ [user_id, password_hash, now, 1 if make_guest else 0])
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -117,8 +134,9 @@ class RegistrationStore(SQLBaseStore):
keyvalues={
"name": user_id,
},
- retcols=["name", "password_hash"],
+ retcols=["name", "password_hash", "is_guest"],
allow_none=True,
+ desc="get_user_by_id",
)
def get_users_by_id_case_insensitive(self, user_id):
@@ -240,9 +258,41 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
+ @cachedInlineCallbacks()
+ def is_guest(self, user_id):
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ defer.returnValue(res if res else False)
+
+ @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+ inlineCallbacks=True)
+ def are_guests(self, user_ids):
+ sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
+ ",".join("?" for _ in user_ids),
+ )
+
+ rows = yield self._execute(
+ "are_guests", self.cursor_to_dict, sql, *user_ids
+ )
+
+ result = {user_id: False for user_id in user_ids}
+
+ result.update({
+ row["name"]: bool(row["is_guest"])
+ for row in rows
+ })
+
+ defer.returnValue(result)
+
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@@ -303,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def find_next_generated_user_id_localpart(self):
+ """
+ Gets the localpart of the next generated user ID.
+
+ Generated user IDs are integers, and we aim for them to be as small as
+ we can. Unfortunately, it's possible some of them are already taken by
+ existing users, and there may be gaps in the already taken range. This
+ function returns the start of the first allocatable gap. This is to
+ avoid the case of ID 10000000 being pre-allocated, so us wasting the
+ first (and shortest) many generated user IDs.
+ """
+ def _find_next_generated_user_id(txn):
+ txn.execute("SELECT name FROM users")
+ rows = self.cursor_to_dict(txn)
+
+ regex = re.compile("^@(\d+):")
+
+ found = set()
+
+ for r in rows:
+ user_id = r["name"]
+ match = regex.search(user_id)
+ if match:
+ found.add(int(match.group(1)))
+ for i in xrange(len(found) + 1):
+ if i not in found:
+ return i
+
+ defer.returnValue((yield self.runInteraction(
+ "find_next_generated_user_id",
+ _find_next_generated_user_id
+ )))
|