diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 80d76bf9d7..9b9572890b 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,9 +19,11 @@ from six.moves import range
from twisted.internet import defer
+from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError
from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
+from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -112,6 +114,187 @@ class RegistrationWorkerStore(SQLBaseStore):
return None
+ @cachedInlineCallbacks()
+ def is_support_user(self, user_id):
+ """Determines if the user is of type UserTypes.SUPPORT
+
+ Args:
+ user_id (str): user id to test
+
+ Returns:
+ Deferred[bool]: True if user is of type UserTypes.SUPPORT
+ """
+ res = yield self.runInteraction(
+ "is_support_user", self.is_support_user_txn, user_id
+ )
+ defer.returnValue(res)
+
+ def is_support_user_txn(self, txn, user_id):
+ res = self._simple_select_one_onecol_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="user_type",
+ allow_none=True,
+ )
+ return True if res == UserTypes.SUPPORT else False
+
+ def get_users_by_id_case_insensitive(self, user_id):
+ """Gets users that match user_id case insensitively.
+ Returns a mapping of user_id -> password_hash.
+ """
+ def f(txn):
+ sql = (
+ "SELECT name, password_hash FROM users"
+ " WHERE lower(name) = lower(?)"
+ )
+ txn.execute(sql, (user_id,))
+ return dict(txn)
+
+ return self.runInteraction("get_users_by_id_case_insensitive", f)
+
+ @defer.inlineCallbacks
+ def count_all_users(self):
+ """Counts all users registered on the homeserver."""
+ def _count_users(txn):
+ txn.execute("SELECT COUNT(*) AS users FROM users")
+ rows = self.cursor_to_dict(txn)
+ if rows:
+ return rows[0]["users"]
+ return 0
+
+ ret = yield self.runInteraction("count_users", _count_users)
+ defer.returnValue(ret)
+
+ def count_daily_user_type(self):
+ """
+ Counts 1) native non guest users
+ 2) native guests users
+ 3) bridged users
+ who registered on the homeserver in the past 24 hours
+ """
+ def _count_daily_user_type(txn):
+ yesterday = int(self._clock.time()) - (60 * 60 * 24)
+
+ sql = """
+ SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+ SELECT
+ CASE
+ WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
+ WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
+ WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
+ END AS user_type
+ FROM users
+ WHERE creation_ts > ?
+ ) AS t GROUP BY user_type
+ """
+ results = {'native': 0, 'guest': 0, 'bridged': 0}
+ txn.execute(sql, (yesterday,))
+ for row in txn:
+ results[row[0]] = row[1]
+ return results
+ return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+
+ @defer.inlineCallbacks
+ def count_nonbridged_users(self):
+ def _count_users(txn):
+ txn.execute("""
+ SELECT COALESCE(COUNT(*), 0) FROM users
+ WHERE appservice_id IS NULL
+ """)
+ count, = txn.fetchone()
+ return count
+
+ ret = yield self.runInteraction("count_users", _count_users)
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def find_next_generated_user_id_localpart(self):
+ """
+ Gets the localpart of the next generated user ID.
+
+ Generated user IDs are integers, and we aim for them to be as small as
+ we can. Unfortunately, it's possible some of them are already taken by
+ existing users, and there may be gaps in the already taken range. This
+ function returns the start of the first allocatable gap. This is to
+ avoid the case of ID 10000000 being pre-allocated, so us wasting the
+ first (and shortest) many generated user IDs.
+ """
+ def _find_next_generated_user_id(txn):
+ txn.execute("SELECT name FROM users")
+
+ regex = re.compile(r"^@(\d+):")
+
+ found = set()
+
+ for user_id, in txn:
+ match = regex.search(user_id)
+ if match:
+ found.add(int(match.group(1)))
+ for i in range(len(found) + 1):
+ if i not in found:
+ return i
+
+ defer.returnValue((yield self.runInteraction(
+ "find_next_generated_user_id",
+ _find_next_generated_user_id
+ )))
+
+ @defer.inlineCallbacks
+ def get_3pid_guest_access_token(self, medium, address):
+ ret = yield self._simple_select_one(
+ "threepid_guest_access_tokens",
+ {
+ "medium": medium,
+ "address": address
+ },
+ ["guest_access_token"], True, 'get_3pid_guest_access_token'
+ )
+ if ret:
+ defer.returnValue(ret["guest_access_token"])
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def get_user_id_by_threepid(self, medium, address):
+ """Returns user id from threepid
+
+ Args:
+ medium (str): threepid medium e.g. email
+ address (str): threepid address e.g. me@example.com
+
+ Returns:
+ Deferred[str|None]: user id or None if no user id/threepid mapping exists
+ """
+ user_id = yield self.runInteraction(
+ "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
+ medium, address
+ )
+ defer.returnValue(user_id)
+
+ def get_user_id_by_threepid_txn(self, txn, medium, address):
+ """Returns user id from threepid
+
+ Args:
+ txn (cursor):
+ medium (str): threepid medium e.g. email
+ address (str): threepid address e.g. me@example.com
+
+ Returns:
+ str|None: user id or None if no user id/threepid mapping exists
+ """
+ ret = self._simple_select_one_txn(
+ txn,
+ "user_threepids",
+ {
+ "medium": medium,
+ "address": address
+ },
+ ['user_id'], True
+ )
+ if ret:
+ return ret['user_id']
+ return None
+
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@@ -167,7 +350,7 @@ class RegistrationStore(RegistrationWorkerStore,
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_localpart=None, admin=False):
+ create_profile_with_displayname=None, admin=False, user_type=None):
"""Attempts to register an account.
Args:
@@ -181,8 +364,12 @@ class RegistrationStore(RegistrationWorkerStore,
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
- create_profile_with_localpart (str): Optionally create a profile for
- the given localpart.
+ create_profile_with_displayname (unicode): Optionally create a profile for
+ the user, setting their displayname to the given value
+ admin (boolean): is an admin user?
+ user_type (str|None): type of user. One of the values from
+ api.constants.UserTypes, or None for a normal user.
+
Raises:
StoreError if the user_id could not be registered.
"""
@@ -195,8 +382,9 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest,
make_guest,
appservice_id,
- create_profile_with_localpart,
- admin
+ create_profile_with_displayname,
+ admin,
+ user_type
)
def _register(
@@ -208,9 +396,12 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest,
make_guest,
appservice_id,
- create_profile_with_localpart,
+ create_profile_with_displayname,
admin,
+ user_type,
):
+ user_id_obj = UserID.from_string(user_id)
+
now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next()
@@ -244,6 +435,7 @@ class RegistrationStore(RegistrationWorkerStore,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
+ "user_type": user_type,
}
)
else:
@@ -257,6 +449,7 @@ class RegistrationStore(RegistrationWorkerStore,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
+ "user_type": user_type,
}
)
except self.database_engine.module.IntegrityError:
@@ -273,12 +466,15 @@ class RegistrationStore(RegistrationWorkerStore,
(next_id, user_id, token,)
)
- if create_profile_with_localpart:
+ if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
+ #
+ # *obviously* the 'profiles' table uses localpart for user_id
+ # while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
- (create_profile_with_localpart, create_profile_with_localpart)
+ (user_id_obj.localpart, create_profile_with_displayname)
)
self._invalidate_cache_and_stream(
@@ -286,20 +482,6 @@ class RegistrationStore(RegistrationWorkerStore,
)
txn.call_after(self.is_guest.invalidate, (user_id,))
- def get_users_by_id_case_insensitive(self, user_id):
- """Gets users that match user_id case insensitively.
- Returns a mapping of user_id -> password_hash.
- """
- def f(txn):
- sql = (
- "SELECT name, password_hash FROM users"
- " WHERE lower(name) = lower(?)"
- )
- txn.execute(sql, (user_id,))
- return dict(txn)
-
- return self.runInteraction("get_users_by_id_case_insensitive", f)
-
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
@@ -472,47 +654,6 @@ class RegistrationStore(RegistrationWorkerStore,
)
defer.returnValue(ret)
- @defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address):
- """Returns user id from threepid
-
- Args:
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
-
- Returns:
- Deferred[str|None]: user id or None if no user id/threepid mapping exists
- """
- user_id = yield self.runInteraction(
- "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
- medium, address
- )
- defer.returnValue(user_id)
-
- def get_user_id_by_threepid_txn(self, txn, medium, address):
- """Returns user id from threepid
-
- Args:
- txn (cursor):
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
-
- Returns:
- str|None: user id or None if no user id/threepid mapping exists
- """
- ret = self._simple_select_one_txn(
- txn,
- "user_threepids",
- {
- "medium": medium,
- "address": address
- },
- ['user_id'], True
- )
- if ret:
- return ret['user_id']
- return None
-
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
@@ -525,107 +666,6 @@ class RegistrationStore(RegistrationWorkerStore,
)
@defer.inlineCallbacks
- def count_all_users(self):
- """Counts all users registered on the homeserver."""
- def _count_users(txn):
- txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.cursor_to_dict(txn)
- if rows:
- return rows[0]["users"]
- return 0
-
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
-
- def count_daily_user_type(self):
- """
- Counts 1) native non guest users
- 2) native guests users
- 3) bridged users
- who registered on the homeserver in the past 24 hours
- """
- def _count_daily_user_type(txn):
- yesterday = int(self._clock.time()) - (60 * 60 * 24)
-
- sql = """
- SELECT user_type, COALESCE(count(*), 0) AS count FROM (
- SELECT
- CASE
- WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
- WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
- WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
- END AS user_type
- FROM users
- WHERE creation_ts > ?
- ) AS t GROUP BY user_type
- """
- results = {'native': 0, 'guest': 0, 'bridged': 0}
- txn.execute(sql, (yesterday,))
- for row in txn:
- results[row[0]] = row[1]
- return results
- return self.runInteraction("count_daily_user_type", _count_daily_user_type)
-
- @defer.inlineCallbacks
- def count_nonbridged_users(self):
- def _count_users(txn):
- txn.execute("""
- SELECT COALESCE(COUNT(*), 0) FROM users
- WHERE appservice_id IS NULL
- """)
- count, = txn.fetchone()
- return count
-
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
-
- @defer.inlineCallbacks
- def find_next_generated_user_id_localpart(self):
- """
- Gets the localpart of the next generated user ID.
-
- Generated user IDs are integers, and we aim for them to be as small as
- we can. Unfortunately, it's possible some of them are already taken by
- existing users, and there may be gaps in the already taken range. This
- function returns the start of the first allocatable gap. This is to
- avoid the case of ID 10000000 being pre-allocated, so us wasting the
- first (and shortest) many generated user IDs.
- """
- def _find_next_generated_user_id(txn):
- txn.execute("SELECT name FROM users")
-
- regex = re.compile(r"^@(\d+):")
-
- found = set()
-
- for user_id, in txn:
- match = regex.search(user_id)
- if match:
- found.add(int(match.group(1)))
- for i in range(len(found) + 1):
- if i not in found:
- return i
-
- defer.returnValue((yield self.runInteraction(
- "find_next_generated_user_id",
- _find_next_generated_user_id
- )))
-
- @defer.inlineCallbacks
- def get_3pid_guest_access_token(self, medium, address):
- ret = yield self._simple_select_one(
- "threepid_guest_access_tokens",
- {
- "medium": medium,
- "address": address
- },
- ["guest_access_token"], True, 'get_3pid_guest_access_token'
- )
- if ret:
- defer.returnValue(ret["guest_access_token"])
- defer.returnValue(None)
-
- @defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id
):
|