diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 5cef9c5c17..f1e4bdea76 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -328,16 +328,49 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
+ def test_password_reset_redirection(self) -> None:
+ """Test basic password reset flow"""
+ old_password = "monkey"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ client_secret = "foobar"
+ next_link = "http://example.com"
+ self._request_token(email, client_secret, "127.0.0.1", next_link)
+
+ self.assertEqual(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link, next_link)
+
def _request_token(
self,
email: str,
client_secret: str,
ip: str = "127.0.0.1",
+ next_link: Optional[str] = None,
) -> str:
+ body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+ if next_link is not None:
+ body["next_link"] = next_link
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ body,
client_ip=ip,
)
@@ -350,7 +383,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
return channel.json_body["sid"]
- def _validate_token(self, link: str) -> None:
+ def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
# Remove the host
path = link.replace("https://example.com", "")
@@ -378,7 +411,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
- self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+ self.assertEqual(
+ HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
+ channel.code,
+ channel.result,
+ )
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
|