diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/rest/client/test_account.py | 43 |
1 files changed, 40 insertions, 3 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 5cef9c5c17..f1e4bdea76 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -328,16 +328,49 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) + def test_password_reset_redirection(self) -> None: + """Test basic password reset flow""" + old_password = "monkey" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + next_link = "http://example.com" + self._request_token(email, client_secret, "127.0.0.1", next_link) + + self.assertEqual(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link, next_link) + def _request_token( self, email: str, client_secret: str, ip: str = "127.0.0.1", + next_link: Optional[str] = None, ) -> str: + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link is not None: + body["next_link"] = next_link channel = self.make_request( "POST", b"account/password/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + body, client_ip=ip, ) @@ -350,7 +383,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): return channel.json_body["sid"] - def _validate_token(self, link: str) -> None: + def _validate_token(self, link: str, next_link: Optional[str] = None) -> None: # Remove the host path = link.replace("https://example.com", "") @@ -378,7 +411,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): shorthand=False, content_is_form=True, ) - self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + self.assertEqual( + HTTPStatus.OK if next_link is None else HTTPStatus.FOUND, + channel.code, + channel.result, + ) def _get_link_from_email(self) -> str: assert self.email_attempts, "No emails have been sent" |