summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py104
1 files changed, 94 insertions, 10 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 09a05b08ef..967c732bda 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014, 2015 OpenMarket Ltd
+# Copyright 2014 - 2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import re
+
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError, Codes
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 
 class RegistrationStore(SQLBaseStore):
@@ -73,30 +75,45 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash):
+    def register(self, user_id, token, password_hash,
+                 was_guest=False, make_guest=False):
         """Attempts to register an account.
 
         Args:
             user_id (str): The desired user ID to register.
             token (str): The desired access token to use for this user.
             password_hash (str): Optional. The password hash for this user.
+            was_guest (bool): Optional. Whether this is a guest account being
+                upgraded to a non-guest account.
+            make_guest (boolean): True if the the new user should be guest,
+                false to add a regular user account.
         Raises:
             StoreError if the user_id could not be registered.
         """
         yield self.runInteraction(
             "register",
-            self._register, user_id, token, password_hash
+            self._register, user_id, token, password_hash, was_guest, make_guest
         )
+        self.is_guest.invalidate((user_id,))
 
-    def _register(self, txn, user_id, token, password_hash):
+    def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
         now = int(self.clock.time())
 
         next_id = self._access_tokens_id_gen.get_next_txn(txn)
 
         try:
-            txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
-                        "VALUES (?,?,?)",
-                        [user_id, password_hash, now])
+            if was_guest:
+                txn.execute("UPDATE users SET"
+                            " password_hash = ?,"
+                            " upgrade_ts = ?,"
+                            " is_guest = ?"
+                            " WHERE name = ?",
+                            [password_hash, now, 1 if make_guest else 0, user_id])
+            else:
+                txn.execute("INSERT INTO users "
+                            "(name, password_hash, creation_ts, is_guest) "
+                            "VALUES (?,?,?,?)",
+                            [user_id, password_hash, now, 1 if make_guest else 0])
         except self.database_engine.module.IntegrityError:
             raise StoreError(
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -117,8 +134,9 @@ class RegistrationStore(SQLBaseStore):
             keyvalues={
                 "name": user_id,
             },
-            retcols=["name", "password_hash"],
+            retcols=["name", "password_hash", "is_guest"],
             allow_none=True,
+            desc="get_user_by_id",
         )
 
     def get_users_by_id_case_insensitive(self, user_id):
@@ -240,9 +258,41 @@ class RegistrationStore(SQLBaseStore):
 
         defer.returnValue(res if res else False)
 
+    @cachedInlineCallbacks()
+    def is_guest(self, user_id):
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        defer.returnValue(res if res else False)
+
+    @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
+                inlineCallbacks=True)
+    def are_guests(self, user_ids):
+        sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
+            ",".join("?" for _ in user_ids),
+        )
+
+        rows = yield self._execute(
+            "are_guests", self.cursor_to_dict, sql, *user_ids
+        )
+
+        result = {user_id: False for user_id in user_ids}
+
+        result.update({
+            row["name"]: bool(row["is_guest"])
+            for row in rows
+        })
+
+        defer.returnValue(result)
+
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, access_tokens.id as token_id"
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
@@ -303,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
 
         ret = yield self.runInteraction("count_users", _count_users)
         defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def find_next_generated_user_id_localpart(self):
+        """
+        Gets the localpart of the next generated user ID.
+
+        Generated user IDs are integers, and we aim for them to be as small as
+        we can. Unfortunately, it's possible some of them are already taken by
+        existing users, and there may be gaps in the already taken range. This
+        function returns the start of the first allocatable gap. This is to
+        avoid the case of ID 10000000 being pre-allocated, so us wasting the
+        first (and shortest) many generated user IDs.
+        """
+        def _find_next_generated_user_id(txn):
+            txn.execute("SELECT name FROM users")
+            rows = self.cursor_to_dict(txn)
+
+            regex = re.compile("^@(\d+):")
+
+            found = set()
+
+            for r in rows:
+                user_id = r["name"]
+                match = regex.search(user_id)
+                if match:
+                    found.add(int(match.group(1)))
+            for i in xrange(len(found) + 1):
+                if i not in found:
+                    return i
+
+        defer.returnValue((yield self.runInteraction(
+            "find_next_generated_user_id",
+            _find_next_generated_user_id
+        )))