summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py338
1 files changed, 182 insertions, 156 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index bd4eb88a92..983a8ec52b 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -18,25 +18,40 @@ import re
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError, Codes
+from synapse.storage import background_updates
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
-
-class RegistrationStore(SQLBaseStore):
+class RegistrationStore(background_updates.BackgroundUpdateStore):
 
     def __init__(self, hs):
         super(RegistrationStore, self).__init__(hs)
 
         self.clock = hs.get_clock()
 
+        self.register_background_index_update(
+            "access_tokens_device_index",
+            index_name="access_tokens_device_id",
+            table="access_tokens",
+            columns=["user_id", "device_id"],
+        )
+
+        self.register_background_index_update(
+            "refresh_tokens_device_index",
+            index_name="refresh_tokens_device_id",
+            table="refresh_tokens",
+            columns=["user_id", "device_id"],
+        )
+
     @defer.inlineCallbacks
-    def add_access_token_to_user(self, user_id, token):
+    def add_access_token_to_user(self, user_id, token, device_id=None):
         """Adds an access token for the given user.
 
         Args:
             user_id (str): The user ID.
             token (str): The new access token to add.
+            device_id (str): ID of the device to associate with the access
+               token
         Raises:
             StoreError if there was a problem adding this.
         """
@@ -47,51 +62,34 @@ class RegistrationStore(SQLBaseStore):
             {
                 "id": next_id,
                 "user_id": user_id,
-                "token": token
+                "token": token,
+                "device_id": device_id,
             },
             desc="add_access_token_to_user",
         )
 
-    @defer.inlineCallbacks
-    def add_refresh_token_to_user(self, user_id, token):
-        """Adds a refresh token for the given user.
-
-        Args:
-            user_id (str): The user ID.
-            token (str): The new refresh token to add.
-        Raises:
-            StoreError if there was a problem adding this.
-        """
-        next_id = self._refresh_tokens_id_gen.get_next()
-
-        yield self._simple_insert(
-            "refresh_tokens",
-            {
-                "id": next_id,
-                "user_id": user_id,
-                "token": token
-            },
-            desc="add_refresh_token_to_user",
-        )
-
-    @defer.inlineCallbacks
-    def register(self, user_id, token, password_hash,
-                 was_guest=False, make_guest=False, appservice_id=None):
+    def register(self, user_id, token=None, password_hash=None,
+                 was_guest=False, make_guest=False, appservice_id=None,
+                 create_profile_with_localpart=None, admin=False):
         """Attempts to register an account.
 
         Args:
             user_id (str): The desired user ID to register.
-            token (str): The desired access token to use for this user.
+            token (str): The desired access token to use for this user. If this
+                is not None, the given access token is associated with the user
+                id.
             password_hash (str): Optional. The password hash for this user.
             was_guest (bool): Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
             make_guest (boolean): True if the the new user should be guest,
                 false to add a regular user account.
             appservice_id (str): The ID of the appservice registering the user.
+            create_profile_with_localpart (str): Optionally create a profile for
+                the given localpart.
         Raises:
             StoreError if the user_id could not be registered.
         """
-        yield self.runInteraction(
+        return self.runInteraction(
             "register",
             self._register,
             user_id,
@@ -99,9 +97,10 @@ class RegistrationStore(SQLBaseStore):
             password_hash,
             was_guest,
             make_guest,
-            appservice_id
+            appservice_id,
+            create_profile_with_localpart,
+            admin
         )
-        self.is_guest.invalidate((user_id,))
 
     def _register(
         self,
@@ -111,7 +110,9 @@ class RegistrationStore(SQLBaseStore):
         password_hash,
         was_guest,
         make_guest,
-        appservice_id
+        appservice_id,
+        create_profile_with_localpart,
+        admin,
     ):
         now = int(self.clock.time())
 
@@ -119,29 +120,48 @@ class RegistrationStore(SQLBaseStore):
 
         try:
             if was_guest:
-                txn.execute("UPDATE users SET"
-                            " password_hash = ?,"
-                            " upgrade_ts = ?,"
-                            " is_guest = ?"
-                            " WHERE name = ?",
-                            [password_hash, now, 1 if make_guest else 0, user_id])
+                # Ensure that the guest user actually exists
+                # ``allow_none=False`` makes this raise an exception
+                # if the row isn't in the database.
+                self._simple_select_one_txn(
+                    txn,
+                    "users",
+                    keyvalues={
+                        "name": user_id,
+                        "is_guest": 1,
+                    },
+                    retcols=("name",),
+                    allow_none=False,
+                )
+
+                self._simple_update_one_txn(
+                    txn,
+                    "users",
+                    keyvalues={
+                        "name": user_id,
+                        "is_guest": 1,
+                    },
+                    updatevalues={
+                        "password_hash": password_hash,
+                        "upgrade_ts": now,
+                        "is_guest": 1 if make_guest else 0,
+                        "appservice_id": appservice_id,
+                        "admin": 1 if admin else 0,
+                    }
+                )
             else:
-                txn.execute("INSERT INTO users "
-                            "("
-                            "   name,"
-                            "   password_hash,"
-                            "   creation_ts,"
-                            "   is_guest,"
-                            "   appservice_id"
-                            ") "
-                            "VALUES (?,?,?,?,?)",
-                            [
-                                user_id,
-                                password_hash,
-                                now,
-                                1 if make_guest else 0,
-                                appservice_id,
-                            ])
+                self._simple_insert_txn(
+                    txn,
+                    "users",
+                    values={
+                        "name": user_id,
+                        "password_hash": password_hash,
+                        "creation_ts": now,
+                        "is_guest": 1 if make_guest else 0,
+                        "appservice_id": appservice_id,
+                        "admin": 1 if admin else 0,
+                    }
+                )
         except self.database_engine.module.IntegrityError:
             raise StoreError(
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
@@ -156,6 +176,18 @@ class RegistrationStore(SQLBaseStore):
                 (next_id, user_id, token,)
             )
 
+        if create_profile_with_localpart:
+            txn.execute(
+                "INSERT INTO profiles(user_id) VALUES (?)",
+                (create_profile_with_localpart,)
+            )
+
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_by_id, (user_id,)
+        )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
+
+    @cached()
     def get_user_by_id(self, user_id):
         return self._simple_select_one(
             table="users",
@@ -181,48 +213,88 @@ class RegistrationStore(SQLBaseStore):
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
-    @defer.inlineCallbacks
     def user_set_password_hash(self, user_id, password_hash):
         """
         NB. This does *not* evict any cache because the one use for this
             removes most of the entries subsequently anyway so it would be
             pointless. Use flush_user separately.
         """
-        yield self._simple_update_one('users', {
-            'name': user_id
-        }, {
-            'password_hash': password_hash
-        })
+        def user_set_password_hash_txn(txn):
+            self._simple_update_one_txn(
+                txn,
+                'users', {
+                    'name': user_id
+                },
+                {
+                    'password_hash': password_hash
+                }
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_id, (user_id,)
+            )
+        return self.runInteraction(
+            "user_set_password_hash", user_set_password_hash_txn
+        )
 
     @defer.inlineCallbacks
-    def user_delete_access_tokens(self, user_id, except_token_ids=[]):
-        def f(txn):
-            sql = "SELECT token FROM access_tokens WHERE user_id = ?"
-            clauses = [user_id]
+    def user_delete_access_tokens(self, user_id, except_token_id=None,
+                                  device_id=None,
+                                  delete_refresh_tokens=False):
+        """
+        Invalidate access/refresh tokens belonging to a user
 
-            if except_token_ids:
-                sql += " AND id NOT IN (%s)" % (
-                    ",".join(["?" for _ in except_token_ids]),
+        Args:
+            user_id (str):  ID of user the tokens belong to
+            except_token_id (str): list of access_tokens IDs which should
+                *not* be deleted
+            device_id (str|None):  ID of device the tokens are associated with.
+                If None, tokens associated with any device (or no device) will
+                be deleted
+            delete_refresh_tokens (bool):  True to delete refresh tokens as
+                well as access tokens.
+        Returns:
+            defer.Deferred:
+        """
+        def f(txn):
+            keyvalues = {
+                "user_id": user_id,
+            }
+            if device_id is not None:
+                keyvalues["device_id"] = device_id
+
+            if delete_refresh_tokens:
+                self._simple_delete_txn(
+                    txn,
+                    table="refresh_tokens",
+                    keyvalues=keyvalues,
                 )
-                clauses += except_token_ids
 
-            txn.execute(sql, clauses)
+            items = keyvalues.items()
+            where_clause = " AND ".join(k + " = ?" for k, _ in items)
+            values = [v for _, v in items]
+            if except_token_id:
+                where_clause += " AND id != ?"
+                values.append(except_token_id)
 
-            rows = txn.fetchall()
-
-            n = 100
-            chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
-            for chunk in chunks:
-                for row in chunk:
-                    txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
+            txn.execute(
+                "SELECT token FROM access_tokens WHERE %s" % where_clause,
+                values
+            )
+            rows = self.cursor_to_dict(txn)
 
-                txn.execute(
-                    "DELETE FROM access_tokens WHERE token in (%s)" % (
-                        ",".join(["?" for _ in chunk]),
-                    ), [r[0] for r in chunk]
+            for row in rows:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (row["token"],)
                 )
 
-        yield self.runInteraction("user_delete_access_tokens", f)
+            txn.execute(
+                "DELETE FROM access_tokens WHERE %s" % where_clause,
+                values
+            )
+
+        yield self.runInteraction(
+            "user_delete_access_tokens", f,
+        )
 
     def delete_access_token(self, access_token):
         def f(txn):
@@ -234,7 +306,9 @@ class RegistrationStore(SQLBaseStore):
                 },
             )
 
-            txn.call_after(self.get_user_by_access_token.invalidate, (access_token,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_access_token, (access_token,)
+            )
 
         return self.runInteraction("delete_access_token", f)
 
@@ -245,9 +319,8 @@ class RegistrationStore(SQLBaseStore):
         Args:
             token (str): The access token of a user.
         Returns:
-            dict: Including the name (user_id) and the ID of their access token.
-        Raises:
-            StoreError if no user was found.
+            defer.Deferred: None, if the token did not match, otherwise dict
+                including the keys `name`, `is_guest`, `device_id`, `token_id`.
         """
         return self.runInteraction(
             "get_user_by_access_token",
@@ -255,46 +328,6 @@ class RegistrationStore(SQLBaseStore):
             token
         )
 
-    def exchange_refresh_token(self, refresh_token, token_generator):
-        """Exchange a refresh token for a new access token and refresh token.
-
-        Doing so invalidates the old refresh token - refresh tokens are single
-        use.
-
-        Args:
-            token (str): The refresh token of a user.
-            token_generator (fn: str -> str): Function which, when given a
-                user ID, returns a unique refresh token for that user. This
-                function must never return the same value twice.
-        Returns:
-            tuple of (user_id, refresh_token)
-        Raises:
-            StoreError if no user was found with that refresh token.
-        """
-        return self.runInteraction(
-            "exchange_refresh_token",
-            self._exchange_refresh_token,
-            refresh_token,
-            token_generator
-        )
-
-    def _exchange_refresh_token(self, txn, old_token, token_generator):
-        sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
-        txn.execute(sql, (old_token,))
-        rows = self.cursor_to_dict(txn)
-        if not rows:
-            raise StoreError(403, "Did not recognize refresh token")
-        user_id = rows[0]["user_id"]
-
-        # TODO(danielwh): Maybe perform a validation on the macaroon that
-        # macaroon.user_id == user_id.
-
-        new_token = token_generator(user_id)
-        sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
-        txn.execute(sql, (new_token, old_token,))
-
-        return user_id, new_token
-
     @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
@@ -319,29 +352,10 @@ class RegistrationStore(SQLBaseStore):
 
         defer.returnValue(res if res else False)
 
-    @cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
-                inlineCallbacks=True)
-    def are_guests(self, user_ids):
-        sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
-            ",".join("?" for _ in user_ids),
-        )
-
-        rows = yield self._execute(
-            "are_guests", self.cursor_to_dict, sql, *user_ids
-        )
-
-        result = {user_id: False for user_id in user_ids}
-
-        result.update({
-            row["name"]: bool(row["is_guest"])
-            for row in rows
-        })
-
-        defer.returnValue(result)
-
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id"
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            " access_tokens.device_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
@@ -390,6 +404,15 @@ class RegistrationStore(SQLBaseStore):
             defer.returnValue(ret['user_id'])
         defer.returnValue(None)
 
+    def user_delete_threepids(self, user_id):
+        return self._simple_delete(
+            "user_threepids",
+            keyvalues={
+                "user_id": user_id,
+            },
+            desc="user_delete_threepids",
+        )
+
     @defer.inlineCallbacks
     def count_all_users(self):
         """Counts all users registered on the homeserver."""
@@ -458,12 +481,15 @@ class RegistrationStore(SQLBaseStore):
         """
         Gets the 3pid's guest access token if exists, else saves access_token.
 
-        :param medium (str): Medium of the 3pid. Must be "email".
-        :param address (str): 3pid address.
-        :param access_token (str): The access token to persist if none is
-            already persisted.
-        :param inviter_user_id (str): User ID of the inviter.
-        :return (deferred str): Whichever access token is persisted at the end
+        Args:
+            medium (str): Medium of the 3pid. Must be "email".
+            address (str): 3pid address.
+            access_token (str): The access token to persist if none is
+                already persisted.
+            inviter_user_id (str): User ID of the inviter.
+
+        Returns:
+            deferred str: Whichever access token is persisted at the end
             of this function call.
         """
         def insert(txn):