diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index bd4eb88a92..983a8ec52b 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,25 +18,40 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
+from synapse.storage import background_updates
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-
-class RegistrationStore(SQLBaseStore):
+class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
+ self.register_background_index_update(
+ "access_tokens_device_index",
+ index_name="access_tokens_device_id",
+ table="access_tokens",
+ columns=["user_id", "device_id"],
+ )
+
+ self.register_background_index_update(
+ "refresh_tokens_device_index",
+ index_name="refresh_tokens_device_id",
+ table="refresh_tokens",
+ columns=["user_id", "device_id"],
+ )
+
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token):
+ def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
+ device_id (str): ID of the device to associate with the access
+ token
Raises:
StoreError if there was a problem adding this.
"""
@@ -47,51 +62,34 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
- "token": token
+ "token": token,
+ "device_id": device_id,
},
desc="add_access_token_to_user",
)
- @defer.inlineCallbacks
- def add_refresh_token_to_user(self, user_id, token):
- """Adds a refresh token for the given user.
-
- Args:
- user_id (str): The user ID.
- token (str): The new refresh token to add.
- Raises:
- StoreError if there was a problem adding this.
- """
- next_id = self._refresh_tokens_id_gen.get_next()
-
- yield self._simple_insert(
- "refresh_tokens",
- {
- "id": next_id,
- "user_id": user_id,
- "token": token
- },
- desc="add_refresh_token_to_user",
- )
-
- @defer.inlineCallbacks
- def register(self, user_id, token, password_hash,
- was_guest=False, make_guest=False, appservice_id=None):
+ def register(self, user_id, token=None, password_hash=None,
+ was_guest=False, make_guest=False, appservice_id=None,
+ create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user.
+ token (str): The desired access token to use for this user. If this
+ is not None, the given access token is associated with the user
+ id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
+ create_profile_with_localpart (str): Optionally create a profile for
+ the given localpart.
Raises:
StoreError if the user_id could not be registered.
"""
- yield self.runInteraction(
+ return self.runInteraction(
"register",
self._register,
user_id,
@@ -99,9 +97,10 @@ class RegistrationStore(SQLBaseStore):
password_hash,
was_guest,
make_guest,
- appservice_id
+ appservice_id,
+ create_profile_with_localpart,
+ admin
)
- self.is_guest.invalidate((user_id,))
def _register(
self,
@@ -111,7 +110,9 @@ class RegistrationStore(SQLBaseStore):
password_hash,
was_guest,
make_guest,
- appservice_id
+ appservice_id,
+ create_profile_with_localpart,
+ admin,
):
now = int(self.clock.time())
@@ -119,29 +120,48 @@ class RegistrationStore(SQLBaseStore):
try:
if was_guest:
- txn.execute("UPDATE users SET"
- " password_hash = ?,"
- " upgrade_ts = ?,"
- " is_guest = ?"
- " WHERE name = ?",
- [password_hash, now, 1 if make_guest else 0, user_id])
+ # Ensure that the guest user actually exists
+ # ``allow_none=False`` makes this raise an exception
+ # if the row isn't in the database.
+ self._simple_select_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ retcols=("name",),
+ allow_none=False,
+ )
+
+ self._simple_update_one_txn(
+ txn,
+ "users",
+ keyvalues={
+ "name": user_id,
+ "is_guest": 1,
+ },
+ updatevalues={
+ "password_hash": password_hash,
+ "upgrade_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
else:
- txn.execute("INSERT INTO users "
- "("
- " name,"
- " password_hash,"
- " creation_ts,"
- " is_guest,"
- " appservice_id"
- ") "
- "VALUES (?,?,?,?,?)",
- [
- user_id,
- password_hash,
- now,
- 1 if make_guest else 0,
- appservice_id,
- ])
+ self._simple_insert_txn(
+ txn,
+ "users",
+ values={
+ "name": user_id,
+ "password_hash": password_hash,
+ "creation_ts": now,
+ "is_guest": 1 if make_guest else 0,
+ "appservice_id": appservice_id,
+ "admin": 1 if admin else 0,
+ }
+ )
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -156,6 +176,18 @@ class RegistrationStore(SQLBaseStore):
(next_id, user_id, token,)
)
+ if create_profile_with_localpart:
+ txn.execute(
+ "INSERT INTO profiles(user_id) VALUES (?)",
+ (create_profile_with_localpart,)
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ txn.call_after(self.is_guest.invalidate, (user_id,))
+
+ @cached()
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
@@ -181,48 +213,88 @@ class RegistrationStore(SQLBaseStore):
return self.runInteraction("get_users_by_id_case_insensitive", f)
- @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
- yield self._simple_update_one('users', {
- 'name': user_id
- }, {
- 'password_hash': password_hash
- })
+ def user_set_password_hash_txn(txn):
+ self._simple_update_one_txn(
+ txn,
+ 'users', {
+ 'name': user_id
+ },
+ {
+ 'password_hash': password_hash
+ }
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction(
+ "user_set_password_hash", user_set_password_hash_txn
+ )
@defer.inlineCallbacks
- def user_delete_access_tokens(self, user_id, except_token_ids=[]):
- def f(txn):
- sql = "SELECT token FROM access_tokens WHERE user_id = ?"
- clauses = [user_id]
+ def user_delete_access_tokens(self, user_id, except_token_id=None,
+ device_id=None,
+ delete_refresh_tokens=False):
+ """
+ Invalidate access/refresh tokens belonging to a user
- if except_token_ids:
- sql += " AND id NOT IN (%s)" % (
- ",".join(["?" for _ in except_token_ids]),
+ Args:
+ user_id (str): ID of user the tokens belong to
+ except_token_id (str): list of access_tokens IDs which should
+ *not* be deleted
+ device_id (str|None): ID of device the tokens are associated with.
+ If None, tokens associated with any device (or no device) will
+ be deleted
+ delete_refresh_tokens (bool): True to delete refresh tokens as
+ well as access tokens.
+ Returns:
+ defer.Deferred:
+ """
+ def f(txn):
+ keyvalues = {
+ "user_id": user_id,
+ }
+ if device_id is not None:
+ keyvalues["device_id"] = device_id
+
+ if delete_refresh_tokens:
+ self._simple_delete_txn(
+ txn,
+ table="refresh_tokens",
+ keyvalues=keyvalues,
)
- clauses += except_token_ids
- txn.execute(sql, clauses)
+ items = keyvalues.items()
+ where_clause = " AND ".join(k + " = ?" for k, _ in items)
+ values = [v for _, v in items]
+ if except_token_id:
+ where_clause += " AND id != ?"
+ values.append(except_token_id)
- rows = txn.fetchall()
-
- n = 100
- chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
- for chunk in chunks:
- for row in chunk:
- txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+ txn.execute(
+ "SELECT token FROM access_tokens WHERE %s" % where_clause,
+ values
+ )
+ rows = self.cursor_to_dict(txn)
- txn.execute(
- "DELETE FROM access_tokens WHERE token in (%s)" % (
- ",".join(["?" for _ in chunk]),
- ), [r[0] for r in chunk]
+ for row in rows:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (row["token"],)
)
- yield self.runInteraction("user_delete_access_tokens", f)
+ txn.execute(
+ "DELETE FROM access_tokens WHERE %s" % where_clause,
+ values
+ )
+
+ yield self.runInteraction(
+ "user_delete_access_tokens", f,
+ )
def delete_access_token(self, access_token):
def f(txn):
@@ -234,7 +306,9 @@ class RegistrationStore(SQLBaseStore):
},
)
- txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (access_token,)
+ )
return self.runInteraction("delete_access_token", f)
@@ -245,9 +319,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
- dict: Including the name (user_id) and the ID of their access token.
- Raises:
- StoreError if no user was found.
+ defer.Deferred: None, if the token did not match, otherwise dict
+ including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
@@ -255,46 +328,6 @@ class RegistrationStore(SQLBaseStore):
token
)
- def exchange_refresh_token(self, refresh_token, token_generator):
- """Exchange a refresh token for a new access token and refresh token.
-
- Doing so invalidates the old refresh token - refresh tokens are single
- use.
-
- Args:
- token (str): The refresh token of a user.
- token_generator (fn: str -> str): Function which, when given a
- user ID, returns a unique refresh token for that user. This
- function must never return the same value twice.
- Returns:
- tuple of (user_id, refresh_token)
- Raises:
- StoreError if no user was found with that refresh token.
- """
- return self.runInteraction(
- "exchange_refresh_token",
- self._exchange_refresh_token,
- refresh_token,
- token_generator
- )
-
- def _exchange_refresh_token(self, txn, old_token, token_generator):
- sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
- txn.execute(sql, (old_token,))
- rows = self.cursor_to_dict(txn)
- if not rows:
- raise StoreError(403, "Did not recognize refresh token")
- user_id = rows[0]["user_id"]
-
- # TODO(danielwh): Maybe perform a validation on the macaroon that
- # macaroon.user_id == user_id.
-
- new_token = token_generator(user_id)
- sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
- txn.execute(sql, (new_token, old_token,))
-
- return user_id, new_token
-
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
@@ -319,29 +352,10 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
- @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
- inlineCallbacks=True)
- def are_guests(self, user_ids):
- sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
- ",".join("?" for _ in user_ids),
- )
-
- rows = yield self._execute(
- "are_guests", self.cursor_to_dict, sql, *user_ids
- )
-
- result = {user_id: False for user_id in user_ids}
-
- result.update({
- row["name"]: bool(row["is_guest"])
- for row in rows
- })
-
- defer.returnValue(result)
-
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.is_guest, access_tokens.id as token_id"
+ "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+ " access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
@@ -390,6 +404,15 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(ret['user_id'])
defer.returnValue(None)
+ def user_delete_threepids(self, user_id):
+ return self._simple_delete(
+ "user_threepids",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="user_delete_threepids",
+ )
+
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -458,12 +481,15 @@ class RegistrationStore(SQLBaseStore):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
- :param medium (str): Medium of the 3pid. Must be "email".
- :param address (str): 3pid address.
- :param access_token (str): The access token to persist if none is
- already persisted.
- :param inviter_user_id (str): User ID of the inviter.
- :return (deferred str): Whichever access token is persisted at the end
+ Args:
+ medium (str): Medium of the 3pid. Must be "email".
+ address (str): 3pid address.
+ access_token (str): The access token to persist if none is
+ already persisted.
+ inviter_user_id (str): User ID of the inviter.
+
+ Returns:
+ deferred str: Whichever access token is persisted at the end
of this function call.
"""
def insert(txn):
|