summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py368
1 files changed, 204 insertions, 164 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 80d76bf9d7..9b9572890b 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,9 +19,11 @@ from six.moves import range
 
 from twisted.internet import defer
 
+from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError
 from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
+from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 
@@ -112,6 +114,187 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         return None
 
+    @cachedInlineCallbacks()
+    def is_support_user(self, user_id):
+        """Determines if the user is of type UserTypes.SUPPORT
+
+        Args:
+            user_id (str): user id to test
+
+        Returns:
+            Deferred[bool]: True if user is of type UserTypes.SUPPORT
+        """
+        res = yield self.runInteraction(
+            "is_support_user", self.is_support_user_txn, user_id
+        )
+        defer.returnValue(res)
+
+    def is_support_user_txn(self, txn, user_id):
+        res = self._simple_select_one_onecol_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="user_type",
+            allow_none=True,
+        )
+        return True if res == UserTypes.SUPPORT else False
+
+    def get_users_by_id_case_insensitive(self, user_id):
+        """Gets users that match user_id case insensitively.
+        Returns a mapping of user_id -> password_hash.
+        """
+        def f(txn):
+            sql = (
+                "SELECT name, password_hash FROM users"
+                " WHERE lower(name) = lower(?)"
+            )
+            txn.execute(sql, (user_id,))
+            return dict(txn)
+
+        return self.runInteraction("get_users_by_id_case_insensitive", f)
+
+    @defer.inlineCallbacks
+    def count_all_users(self):
+        """Counts all users registered on the homeserver."""
+        def _count_users(txn):
+            txn.execute("SELECT COUNT(*) AS users FROM users")
+            rows = self.cursor_to_dict(txn)
+            if rows:
+                return rows[0]["users"]
+            return 0
+
+        ret = yield self.runInteraction("count_users", _count_users)
+        defer.returnValue(ret)
+
+    def count_daily_user_type(self):
+        """
+        Counts 1) native non guest users
+               2) native guests users
+               3) bridged users
+        who registered on the homeserver in the past 24 hours
+        """
+        def _count_daily_user_type(txn):
+            yesterday = int(self._clock.time()) - (60 * 60 * 24)
+
+            sql = """
+                SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+                    SELECT
+                    CASE
+                        WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
+                        WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
+                        WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
+                    END AS user_type
+                    FROM users
+                    WHERE creation_ts > ?
+                ) AS t GROUP BY user_type
+            """
+            results = {'native': 0, 'guest': 0, 'bridged': 0}
+            txn.execute(sql, (yesterday,))
+            for row in txn:
+                results[row[0]] = row[1]
+            return results
+        return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+
+    @defer.inlineCallbacks
+    def count_nonbridged_users(self):
+        def _count_users(txn):
+            txn.execute("""
+                SELECT COALESCE(COUNT(*), 0) FROM users
+                WHERE appservice_id IS NULL
+            """)
+            count, = txn.fetchone()
+            return count
+
+        ret = yield self.runInteraction("count_users", _count_users)
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def find_next_generated_user_id_localpart(self):
+        """
+        Gets the localpart of the next generated user ID.
+
+        Generated user IDs are integers, and we aim for them to be as small as
+        we can. Unfortunately, it's possible some of them are already taken by
+        existing users, and there may be gaps in the already taken range. This
+        function returns the start of the first allocatable gap. This is to
+        avoid the case of ID 10000000 being pre-allocated, so us wasting the
+        first (and shortest) many generated user IDs.
+        """
+        def _find_next_generated_user_id(txn):
+            txn.execute("SELECT name FROM users")
+
+            regex = re.compile(r"^@(\d+):")
+
+            found = set()
+
+            for user_id, in txn:
+                match = regex.search(user_id)
+                if match:
+                    found.add(int(match.group(1)))
+            for i in range(len(found) + 1):
+                if i not in found:
+                    return i
+
+        defer.returnValue((yield self.runInteraction(
+            "find_next_generated_user_id",
+            _find_next_generated_user_id
+        )))
+
+    @defer.inlineCallbacks
+    def get_3pid_guest_access_token(self, medium, address):
+        ret = yield self._simple_select_one(
+            "threepid_guest_access_tokens",
+            {
+                "medium": medium,
+                "address": address
+            },
+            ["guest_access_token"], True, 'get_3pid_guest_access_token'
+        )
+        if ret:
+            defer.returnValue(ret["guest_access_token"])
+        defer.returnValue(None)
+
+    @defer.inlineCallbacks
+    def get_user_id_by_threepid(self, medium, address):
+        """Returns user id from threepid
+
+        Args:
+            medium (str): threepid medium e.g. email
+            address (str): threepid address e.g. me@example.com
+
+        Returns:
+            Deferred[str|None]: user id or None if no user id/threepid mapping exists
+        """
+        user_id = yield self.runInteraction(
+            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
+            medium, address
+        )
+        defer.returnValue(user_id)
+
+    def get_user_id_by_threepid_txn(self, txn, medium, address):
+        """Returns user id from threepid
+
+        Args:
+            txn (cursor):
+            medium (str): threepid medium e.g. email
+            address (str): threepid address e.g. me@example.com
+
+        Returns:
+            str|None: user id or None if no user id/threepid mapping exists
+        """
+        ret = self._simple_select_one_txn(
+            txn,
+            "user_threepids",
+            {
+                "medium": medium,
+                "address": address
+            },
+            ['user_id'], True
+        )
+        if ret:
+            return ret['user_id']
+        return None
+
 
 class RegistrationStore(RegistrationWorkerStore,
                         background_updates.BackgroundUpdateStore):
@@ -167,7 +350,7 @@ class RegistrationStore(RegistrationWorkerStore,
 
     def register(self, user_id, token=None, password_hash=None,
                  was_guest=False, make_guest=False, appservice_id=None,
-                 create_profile_with_localpart=None, admin=False):
+                 create_profile_with_displayname=None, admin=False, user_type=None):
         """Attempts to register an account.
 
         Args:
@@ -181,8 +364,12 @@ class RegistrationStore(RegistrationWorkerStore,
             make_guest (boolean): True if the the new user should be guest,
                 false to add a regular user account.
             appservice_id (str): The ID of the appservice registering the user.
-            create_profile_with_localpart (str): Optionally create a profile for
-                the given localpart.
+            create_profile_with_displayname (unicode): Optionally create a profile for
+                the user, setting their displayname to the given value
+            admin (boolean): is an admin user?
+            user_type (str|None): type of user. One of the values from
+                api.constants.UserTypes, or None for a normal user.
+
         Raises:
             StoreError if the user_id could not be registered.
         """
@@ -195,8 +382,9 @@ class RegistrationStore(RegistrationWorkerStore,
             was_guest,
             make_guest,
             appservice_id,
-            create_profile_with_localpart,
-            admin
+            create_profile_with_displayname,
+            admin,
+            user_type
         )
 
     def _register(
@@ -208,9 +396,12 @@ class RegistrationStore(RegistrationWorkerStore,
         was_guest,
         make_guest,
         appservice_id,
-        create_profile_with_localpart,
+        create_profile_with_displayname,
         admin,
+        user_type,
     ):
+        user_id_obj = UserID.from_string(user_id)
+
         now = int(self.clock.time())
 
         next_id = self._access_tokens_id_gen.get_next()
@@ -244,6 +435,7 @@ class RegistrationStore(RegistrationWorkerStore,
                         "is_guest": 1 if make_guest else 0,
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
+                        "user_type": user_type,
                     }
                 )
             else:
@@ -257,6 +449,7 @@ class RegistrationStore(RegistrationWorkerStore,
                         "is_guest": 1 if make_guest else 0,
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
+                        "user_type": user_type,
                     }
                 )
         except self.database_engine.module.IntegrityError:
@@ -273,12 +466,15 @@ class RegistrationStore(RegistrationWorkerStore,
                 (next_id, user_id, token,)
             )
 
-        if create_profile_with_localpart:
+        if create_profile_with_displayname:
             # set a default displayname serverside to avoid ugly race
             # between auto-joins and clients trying to set displaynames
+            #
+            # *obviously* the 'profiles' table uses localpart for user_id
+            # while everything else uses the full mxid.
             txn.execute(
                 "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
-                (create_profile_with_localpart, create_profile_with_localpart)
+                (user_id_obj.localpart, create_profile_with_displayname)
             )
 
         self._invalidate_cache_and_stream(
@@ -286,20 +482,6 @@ class RegistrationStore(RegistrationWorkerStore,
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    def get_users_by_id_case_insensitive(self, user_id):
-        """Gets users that match user_id case insensitively.
-        Returns a mapping of user_id -> password_hash.
-        """
-        def f(txn):
-            sql = (
-                "SELECT name, password_hash FROM users"
-                " WHERE lower(name) = lower(?)"
-            )
-            txn.execute(sql, (user_id,))
-            return dict(txn)
-
-        return self.runInteraction("get_users_by_id_case_insensitive", f)
-
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this
@@ -472,47 +654,6 @@ class RegistrationStore(RegistrationWorkerStore,
         )
         defer.returnValue(ret)
 
-    @defer.inlineCallbacks
-    def get_user_id_by_threepid(self, medium, address):
-        """Returns user id from threepid
-
-        Args:
-            medium (str): threepid medium e.g. email
-            address (str): threepid address e.g. me@example.com
-
-        Returns:
-            Deferred[str|None]: user id or None if no user id/threepid mapping exists
-        """
-        user_id = yield self.runInteraction(
-            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
-            medium, address
-        )
-        defer.returnValue(user_id)
-
-    def get_user_id_by_threepid_txn(self, txn, medium, address):
-        """Returns user id from threepid
-
-        Args:
-            txn (cursor):
-            medium (str): threepid medium e.g. email
-            address (str): threepid address e.g. me@example.com
-
-        Returns:
-            str|None: user id or None if no user id/threepid mapping exists
-        """
-        ret = self._simple_select_one_txn(
-            txn,
-            "user_threepids",
-            {
-                "medium": medium,
-                "address": address
-            },
-            ['user_id'], True
-        )
-        if ret:
-            return ret['user_id']
-        return None
-
     def user_delete_threepid(self, user_id, medium, address):
         return self._simple_delete(
             "user_threepids",
@@ -525,107 +666,6 @@ class RegistrationStore(RegistrationWorkerStore,
         )
 
     @defer.inlineCallbacks
-    def count_all_users(self):
-        """Counts all users registered on the homeserver."""
-        def _count_users(txn):
-            txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
-
-        ret = yield self.runInteraction("count_users", _count_users)
-        defer.returnValue(ret)
-
-    def count_daily_user_type(self):
-        """
-        Counts 1) native non guest users
-               2) native guests users
-               3) bridged users
-        who registered on the homeserver in the past 24 hours
-        """
-        def _count_daily_user_type(txn):
-            yesterday = int(self._clock.time()) - (60 * 60 * 24)
-
-            sql = """
-                SELECT user_type, COALESCE(count(*), 0) AS count FROM (
-                    SELECT
-                    CASE
-                        WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
-                        WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
-                        WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
-                    END AS user_type
-                    FROM users
-                    WHERE creation_ts > ?
-                ) AS t GROUP BY user_type
-            """
-            results = {'native': 0, 'guest': 0, 'bridged': 0}
-            txn.execute(sql, (yesterday,))
-            for row in txn:
-                results[row[0]] = row[1]
-            return results
-        return self.runInteraction("count_daily_user_type", _count_daily_user_type)
-
-    @defer.inlineCallbacks
-    def count_nonbridged_users(self):
-        def _count_users(txn):
-            txn.execute("""
-                SELECT COALESCE(COUNT(*), 0) FROM users
-                WHERE appservice_id IS NULL
-            """)
-            count, = txn.fetchone()
-            return count
-
-        ret = yield self.runInteraction("count_users", _count_users)
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
-    def find_next_generated_user_id_localpart(self):
-        """
-        Gets the localpart of the next generated user ID.
-
-        Generated user IDs are integers, and we aim for them to be as small as
-        we can. Unfortunately, it's possible some of them are already taken by
-        existing users, and there may be gaps in the already taken range. This
-        function returns the start of the first allocatable gap. This is to
-        avoid the case of ID 10000000 being pre-allocated, so us wasting the
-        first (and shortest) many generated user IDs.
-        """
-        def _find_next_generated_user_id(txn):
-            txn.execute("SELECT name FROM users")
-
-            regex = re.compile(r"^@(\d+):")
-
-            found = set()
-
-            for user_id, in txn:
-                match = regex.search(user_id)
-                if match:
-                    found.add(int(match.group(1)))
-            for i in range(len(found) + 1):
-                if i not in found:
-                    return i
-
-        defer.returnValue((yield self.runInteraction(
-            "find_next_generated_user_id",
-            _find_next_generated_user_id
-        )))
-
-    @defer.inlineCallbacks
-    def get_3pid_guest_access_token(self, medium, address):
-        ret = yield self._simple_select_one(
-            "threepid_guest_access_tokens",
-            {
-                "medium": medium,
-                "address": address
-            },
-            ["guest_access_token"], True, 'get_3pid_guest_access_token'
-        )
-        if ret:
-            defer.returnValue(ret["guest_access_token"])
-        defer.returnValue(None)
-
-    @defer.inlineCallbacks
     def save_or_get_3pid_guest_access_token(
             self, medium, address, access_token, inviter_user_id
     ):