summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorQuentin Gliech <quentingliech@gmail.com>2021-06-24 15:33:20 +0200
committerGitHub <noreply@github.com>2021-06-24 14:33:20 +0100
commitbd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8 (patch)
tree04a988e47720e9c58c99f05b74121e03ebe1f5f4 /synapse/storage/databases
parentMerge tag 'v1.37.0rc1' into develop (diff)
downloadsynapse-bd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8.tar.xz
MSC2918 Refresh tokens implementation (#9450)
This implements refresh tokens, as defined by MSC2918

This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235

The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one.

Signed-off-by: Quentin Gliech <quentingliech@gmail.com>
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/registration.py207
1 files changed, 203 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e5c5cf8ff0..e31c5864ac 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -53,6 +53,9 @@ class TokenLookupResult:
         valid_until_ms: The timestamp the token expires, if any.
         token_owner: The "owner" of the token. This is either the same as the
             user, or a server admin who is logged in as the user.
+        token_used: True if this token was used at least once in a request.
+            This field can be out of date since `get_user_by_access_token` is
+            cached.
     """
 
     user_id = attr.ib(type=str)
@@ -62,6 +65,7 @@ class TokenLookupResult:
     device_id = attr.ib(type=Optional[str], default=None)
     valid_until_ms = attr.ib(type=Optional[int], default=None)
     token_owner = attr.ib(type=str)
+    token_used = attr.ib(type=bool, default=False)
 
     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
@@ -69,6 +73,29 @@ class TokenLookupResult:
         return self.user_id
 
 
+@attr.s(frozen=True, slots=True)
+class RefreshTokenLookupResult:
+    """Result of looking up a refresh token."""
+
+    user_id = attr.ib(type=str)
+    """The user this token belongs to."""
+
+    device_id = attr.ib(type=str)
+    """The device associated with this refresh token."""
+
+    token_id = attr.ib(type=int)
+    """The ID of this refresh token."""
+
+    next_token_id = attr.ib(type=Optional[int])
+    """The ID of the refresh token which replaced this one."""
+
+    has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+    """True if the next refresh token was used for another refresh."""
+
+    has_next_access_token_been_used = attr.ib(type=bool)
+    """True if the next access token was already used at least once."""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 access_tokens.id as token_id,
                 access_tokens.device_id,
                 access_tokens.valid_until_ms,
-                access_tokens.user_id as token_owner
+                access_tokens.user_id as token_owner,
+                access_tokens.used as token_used
             FROM users
             INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
             WHERE token = ?
@@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
+
         if rows:
-            return TokenLookupResult(**rows[0])
+            row = rows[0]
+
+            # This field is nullable, ensure it comes out as a boolean
+            if row["token_used"] is None:
+                row["token_used"] = False
+
+            return TokenLookupResult(**row)
 
         return None
 
@@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="update_access_token_last_validated",
         )
 
+    @cached()
+    async def mark_access_token_as_used(self, token_id: int) -> None:
+        """
+        Mark the access token as used, which invalidates the refresh token used
+        to obtain it.
+
+        Because get_user_by_access_token is cached, this function might be
+        called multiple times for the same token, effectively doing unnecessary
+        SQL updates. Because updating the `used` field only goes one way (from
+        False to True) it is safe to cache this function as well to avoid this
+        issue.
+
+        Args:
+            token_id: The ID of the access token to update.
+        Raises:
+            StoreError if there was a problem updating this.
+        """
+        await self.db_pool.simple_update_one(
+            "access_tokens",
+            {"id": token_id},
+            {"used": True},
+            desc="mark_access_token_as_used",
+        )
+
+    async def lookup_refresh_token(
+        self, token: str
+    ) -> Optional[RefreshTokenLookupResult]:
+        """Lookup a refresh token with hints about its validity."""
+
+        def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+            txn.execute(
+                """
+                SELECT
+                    rt.id token_id,
+                    rt.user_id,
+                    rt.device_id,
+                    rt.next_token_id,
+                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
+                    at.used has_next_access_token_been_used
+                FROM refresh_tokens rt
+                LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
+                LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
+                WHERE rt.token = ?
+            """,
+                (token,),
+            )
+            row = txn.fetchone()
+
+            if row is None:
+                return None
+
+            return RefreshTokenLookupResult(
+                token_id=row[0],
+                user_id=row[1],
+                device_id=row[2],
+                next_token_id=row[3],
+                has_next_refresh_token_been_refreshed=row[4],
+                # This column is nullable, ensure it's a boolean
+                has_next_access_token_been_used=(row[5] or False),
+            )
+
+        return await self.db_pool.runInteraction(
+            "lookup_refresh_token", _lookup_refresh_token_txn
+        )
+
+    async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
+        """
+        Set the successor of a refresh token, removing the existing successor
+        if any.
+
+        Args:
+            token_id: ID of the refresh token to update.
+            next_token_id: ID of its successor.
+        """
+
+        def _replace_refresh_token_txn(txn) -> None:
+            # First check if there was an existing refresh token
+            old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                "refresh_tokens",
+                {"id": token_id},
+                "next_token_id",
+                allow_none=True,
+            )
+
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "refresh_tokens",
+                {"id": token_id},
+                {"next_token_id": next_token_id},
+            )
+
+            # Delete the old "next" token if it exists. This should cascade and
+            # delete the associated access_token
+            if old_next_token_id is not None:
+                self.db_pool.simple_delete_one_txn(
+                    txn,
+                    "refresh_tokens",
+                    {"id": old_next_token_id},
+                )
+
+        await self.db_pool.runInteraction(
+            "replace_refresh_token", _replace_refresh_token_txn
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(
@@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+        self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
 
     async def add_access_token_to_user(
         self,
@@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         device_id: Optional[str],
         valid_until_ms: Optional[int],
         puppets_user_id: Optional[str] = None,
+        refresh_token_id: Optional[int] = None,
     ) -> int:
         """Adds an access token for the given user.
 
         Args:
             user_id: The user ID.
             token: The new access token to add.
-            device_id: ID of the device to associate with the access token
+            device_id: ID of the device to associate with the access token.
             valid_until_ms: when the token is valid until. None for no expiry.
+            puppets_user_id
+            refresh_token_id: ID of the refresh token generated alongside this
+                access token.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "valid_until_ms": valid_until_ms,
                 "puppets_user_id": puppets_user_id,
                 "last_validated": now,
+                "refresh_token_id": refresh_token_id,
+                "used": False,
             },
             desc="add_access_token_to_user",
         )
 
         return next_id
 
+    async def add_refresh_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        device_id: Optional[str],
+    ) -> int:
+        """Adds a refresh token for the given user.
+
+        Args:
+            user_id: The user ID.
+            token: The new access token to add.
+            device_id: ID of the device to associate with the refresh token.
+        Raises:
+            StoreError if there was a problem adding this.
+        Returns:
+            The token ID
+        """
+        next_id = self._refresh_tokens_id_gen.get_next()
+
+        await self.db_pool.simple_insert(
+            "refresh_tokens",
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "device_id": device_id,
+                "token": token,
+                "next_token_id": None,
+            },
+            desc="add_refresh_token_to_user",
+        )
+
+        return next_id
+
     def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
@@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         device_id: Optional[str] = None,
     ) -> List[Tuple[str, int, Optional[str]]]:
         """
-        Invalidate access tokens belonging to a user
+        Invalidate access and refresh tokens belonging to a user
 
         Args:
             user_id: ID of user the tokens belong to
@@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
             values = [v for _, v in items]  # type: List[Union[str, int]]
+            # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
+            # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
+            # clause and values before we handle that. This seems to be only used in the "set password" handler.
+            refresh_where_clause = where_clause
+            refresh_values = values.copy()
             if except_token_id:
+                # TODO: support that for refresh tokens
                 where_clause += " AND id != ?"
                 values.append(except_token_id)
 
@@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 
+            txn.execute(
+                "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
+                refresh_values,
+            )
+
             return tokens_and_devices
 
         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
@@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         await self.db_pool.runInteraction("delete_access_token", f)
 
+    async def delete_refresh_token(self, refresh_token: str) -> None:
+        def f(txn):
+            self.db_pool.simple_delete_one_txn(
+                txn, table="refresh_tokens", keyvalues={"token": refresh_token}
+            )
+
+        await self.db_pool.runInteraction("delete_refresh_token", f)
+
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're