summary refs log tree commit diff
path: root/synapse/storage/registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/registration.py')
-rw-r--r--synapse/storage/registration.py102
1 files changed, 88 insertions, 14 deletions
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 586628579d..2e5eddd259 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -51,6 +51,28 @@ class RegistrationStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
+    def add_refresh_token_to_user(self, user_id, token):
+        """Adds a refresh token for the given user.
+
+        Args:
+            user_id (str): The user ID.
+            token (str): The new refresh token to add.
+        Raises:
+            StoreError if there was a problem adding this.
+        """
+        next_id = yield self._refresh_tokens_id_gen.get_next()
+
+        yield self._simple_insert(
+            "refresh_tokens",
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "token": token
+            },
+            desc="add_refresh_token_to_user",
+        )
+
+    @defer.inlineCallbacks
     def register(self, user_id, token, password_hash):
         """Attempts to register an account.
 
@@ -80,13 +102,14 @@ class RegistrationStore(SQLBaseStore):
                 400, "User ID already taken.", errcode=Codes.USER_IN_USE
             )
 
-        # it's possible for this to get a conflict, but only for a single user
-        # since tokens are namespaced based on their user ID
-        txn.execute(
-            "INSERT INTO access_tokens(id, user_id, token)"
-            " VALUES (?,?,?)",
-            (next_id, user_id, token,)
-        )
+        if token:
+            # it's possible for this to get a conflict, but only for a single user
+            # since tokens are namespaced based on their user ID
+            txn.execute(
+                "INSERT INTO access_tokens(id, user_id, token)"
+                " VALUES (?,?,?)",
+                (next_id, user_id, token,)
+            )
 
     def get_user_by_id(self, user_id):
         return self._simple_select_one(
@@ -146,26 +169,65 @@ class RegistrationStore(SQLBaseStore):
             user_id
         )
         for r in rows:
-            self.get_user_by_token.invalidate((r,))
+            self.get_user_by_access_token.invalidate((r,))
 
     @cached()
-    def get_user_by_token(self, token):
+    def get_user_by_access_token(self, token):
         """Get a user from the given access token.
 
         Args:
             token (str): The access token of a user.
         Returns:
-            dict: Including the name (user_id), device_id and whether they are
-                an admin.
+            dict: Including the name (user_id) and the ID of their access token.
         Raises:
             StoreError if no user was found.
         """
         return self.runInteraction(
-            "get_user_by_token",
+            "get_user_by_access_token",
             self._query_for_auth,
             token
         )
 
+    def exchange_refresh_token(self, refresh_token, token_generator):
+        """Exchange a refresh token for a new access token and refresh token.
+
+        Doing so invalidates the old refresh token - refresh tokens are single
+        use.
+
+        Args:
+            token (str): The refresh token of a user.
+            token_generator (fn: str -> str): Function which, when given a
+                user ID, returns a unique refresh token for that user. This
+                function must never return the same value twice.
+        Returns:
+            tuple of (user_id, refresh_token)
+        Raises:
+            StoreError if no user was found with that refresh token.
+        """
+        return self.runInteraction(
+            "exchange_refresh_token",
+            self._exchange_refresh_token,
+            refresh_token,
+            token_generator
+        )
+
+    def _exchange_refresh_token(self, txn, old_token, token_generator):
+        sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
+        txn.execute(sql, (old_token,))
+        rows = self.cursor_to_dict(txn)
+        if not rows:
+            raise StoreError(403, "Did not recognize refresh token")
+        user_id = rows[0]["user_id"]
+
+        # TODO(danielwh): Maybe perform a validation on the macaroon that
+        # macaroon.user_id == user_id.
+
+        new_token = token_generator(user_id)
+        sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
+        txn.execute(sql, (new_token, old_token,))
+
+        return user_id, new_token
+
     @defer.inlineCallbacks
     def is_server_admin(self, user):
         res = yield self._simple_select_one_onecol(
@@ -180,8 +242,7 @@ class RegistrationStore(SQLBaseStore):
 
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, users.admin,"
-            " access_tokens.device_id, access_tokens.id as token_id"
+            "SELECT users.name, access_tokens.id as token_id"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
             " WHERE token = ?"
@@ -229,3 +290,16 @@ class RegistrationStore(SQLBaseStore):
         if ret:
             defer.returnValue(ret['user_id'])
         defer.returnValue(None)
+
+    @defer.inlineCallbacks
+    def count_all_users(self):
+        """Counts all users registered on the homeserver."""
+        def _count_users(txn):
+            txn.execute("SELECT COUNT(*) AS users FROM users")
+            rows = self.cursor_to_dict(txn)
+            if rows:
+                return rows[0]["users"]
+            return 0
+
+        ret = yield self.runInteraction("count_users", _count_users)
+        defer.returnValue(ret)