diff options
Diffstat (limited to 'tests/rest/client/test_account.py')
-rw-r--r-- | tests/rest/client/test_account.py | 69 |
1 files changed, 56 insertions, 13 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index bd59bb50cf..f1e4bdea76 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -1,16 +1,22 @@ -# Copyright 2022 The Matrix.org Foundation C.I.C. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# Originally licensed under the Apache License, Version 2.0: +# <http://www.apache.org/licenses/LICENSE-2.0>. +# +# [This file includes modifications made by New Vector Limited] # -# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import os import re from email.parser import Parser @@ -322,16 +328,49 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) + def test_password_reset_redirection(self) -> None: + """Test basic password reset flow""" + old_password = "monkey" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + next_link = "http://example.com" + self._request_token(email, client_secret, "127.0.0.1", next_link) + + self.assertEqual(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link, next_link) + def _request_token( self, email: str, client_secret: str, ip: str = "127.0.0.1", + next_link: Optional[str] = None, ) -> str: + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link is not None: + body["next_link"] = next_link channel = self.make_request( "POST", b"account/password/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + body, client_ip=ip, ) @@ -344,7 +383,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] - def _validate_token(self, link: str) -> None: + def _validate_token(self, link: str, next_link: Optional[str] = None) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -372,7 +411,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertEqual( + HTTPStatus.OK if next_link is None else HTTPStatus.FOUND, + channel.code, + channel.result, + ) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" |