diff options
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r-- | synapse/storage/databases/main/registration.py | 207 |
1 files changed, 203 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e5c5cf8ff0..e31c5864ac 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -53,6 +53,9 @@ class TokenLookupResult: valid_until_ms: The timestamp the token expires, if any. token_owner: The "owner" of the token. This is either the same as the user, or a server admin who is logged in as the user. + token_used: True if this token was used at least once in a request. + This field can be out of date since `get_user_by_access_token` is + cached. """ user_id = attr.ib(type=str) @@ -62,6 +65,7 @@ class TokenLookupResult: device_id = attr.ib(type=Optional[str], default=None) valid_until_ms = attr.ib(type=Optional[int], default=None) token_owner = attr.ib(type=str) + token_used = attr.ib(type=bool, default=False) # Make the token owner default to the user ID, which is the common case. @token_owner.default @@ -69,6 +73,29 @@ class TokenLookupResult: return self.user_id +@attr.s(frozen=True, slots=True) +class RefreshTokenLookupResult: + """Result of looking up a refresh token.""" + + user_id = attr.ib(type=str) + """The user this token belongs to.""" + + device_id = attr.ib(type=str) + """The device associated with this refresh token.""" + + token_id = attr.ib(type=int) + """The ID of this refresh token.""" + + next_token_id = attr.ib(type=Optional[int]) + """The ID of the refresh token which replaced this one.""" + + has_next_refresh_token_been_refreshed = attr.ib(type=bool) + """True if the next refresh token was used for another refresh.""" + + has_next_access_token_been_used = attr.ib(type=bool) + """True if the next access token was already used at least once.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): access_tokens.id as token_id, access_tokens.device_id, access_tokens.valid_until_ms, - access_tokens.user_id as token_owner + access_tokens.user_id as token_owner, + access_tokens.used as token_used FROM users INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? @@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) + if rows: - return TokenLookupResult(**rows[0]) + row = rows[0] + + # This field is nullable, ensure it comes out as a boolean + if row["token_used"] is None: + row["token_used"] = False + + return TokenLookupResult(**row) return None @@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + @cached() + async def mark_access_token_as_used(self, token_id: int) -> None: + """ + Mark the access token as used, which invalidates the refresh token used + to obtain it. + + Because get_user_by_access_token is cached, this function might be + called multiple times for the same token, effectively doing unnecessary + SQL updates. Because updating the `used` field only goes one way (from + False to True) it is safe to cache this function as well to avoid this + issue. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"used": True}, + desc="mark_access_token_as_used", + ) + + async def lookup_refresh_token( + self, token: str + ) -> Optional[RefreshTokenLookupResult]: + """Lookup a refresh token with hints about its validity.""" + + def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + txn.execute( + """ + SELECT + rt.id token_id, + rt.user_id, + rt.device_id, + rt.next_token_id, + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used + FROM refresh_tokens rt + LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id + LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id + WHERE rt.token = ? + """, + (token,), + ) + row = txn.fetchone() + + if row is None: + return None + + return RefreshTokenLookupResult( + token_id=row[0], + user_id=row[1], + device_id=row[2], + next_token_id=row[3], + has_next_refresh_token_been_refreshed=row[4], + # This column is nullable, ensure it's a boolean + has_next_access_token_been_used=(row[5] or False), + ) + + return await self.db_pool.runInteraction( + "lookup_refresh_token", _lookup_refresh_token_txn + ) + + async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: + """ + Set the successor of a refresh token, removing the existing successor + if any. + + Args: + token_id: ID of the refresh token to update. + next_token_id: ID of its successor. + """ + + def _replace_refresh_token_txn(txn) -> None: + # First check if there was an existing refresh token + old_next_token_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "refresh_tokens", + {"id": token_id}, + "next_token_id", + allow_none=True, + ) + + self.db_pool.simple_update_one_txn( + txn, + "refresh_tokens", + {"id": token_id}, + {"next_token_id": next_token_id}, + ) + + # Delete the old "next" token if it exists. This should cascade and + # delete the associated access_token + if old_next_token_id is not None: + self.db_pool.simple_delete_one_txn( + txn, + "refresh_tokens", + {"id": old_next_token_id}, + ) + + await self.db_pool.runInteraction( + "replace_refresh_token", _replace_refresh_token_txn + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") async def add_access_token_to_user( self, @@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + refresh_token_id: Optional[int] = None, ) -> int: """Adds an access token for the given user. Args: user_id: The user ID. token: The new access token to add. - device_id: ID of the device to associate with the access token + device_id: ID of the device to associate with the access token. valid_until_ms: when the token is valid until. None for no expiry. + puppets_user_id + refresh_token_id: ID of the refresh token generated alongside this + access token. Raises: StoreError if there was a problem adding this. Returns: @@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, "last_validated": now, + "refresh_token_id": refresh_token_id, + "used": False, }, desc="add_access_token_to_user", ) return next_id + async def add_refresh_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + ) -> int: + """Adds a refresh token for the given user. + + Args: + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the refresh token. + Raises: + StoreError if there was a problem adding this. + Returns: + The token ID + """ + next_id = self._refresh_tokens_id_gen.get_next() + + await self.db_pool.simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "device_id": device_id, + "token": token, + "next_token_id": None, + }, + desc="add_refresh_token_to_user", + ) + + return next_id + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" @@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ - Invalidate access tokens belonging to a user + Invalidate access and refresh tokens belonging to a user Args: user_id: ID of user the tokens belong to @@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] # type: List[Union[str, int]] + # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat + # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where + # clause and values before we handle that. This seems to be only used in the "set password" handler. + refresh_where_clause = where_clause + refresh_values = values.copy() if except_token_id: + # TODO: support that for refresh tokens where_clause += " AND id != ?" values.append(except_token_id) @@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + txn.execute( + "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause, + refresh_values, + ) + return tokens_and_devices return await self.db_pool.runInteraction("user_delete_access_tokens", f) @@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) + async def delete_refresh_token(self, refresh_token: str) -> None: + def f(txn): + self.db_pool.simple_delete_one_txn( + txn, table="refresh_tokens", keyvalues={"token": refresh_token} + ) + + await self.db_pool.runInteraction("delete_refresh_token", f) + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're |