summary refs log tree commit diff
path: root/tests/rest/client/test_account.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_account.py')
-rw-r--r--tests/rest/client/test_account.py43
1 files changed, 40 insertions, 3 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 5cef9c5c17..f1e4bdea76 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -328,16 +328,49 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
+    def test_password_reset_redirection(self) -> None:
+        """Test basic password reset flow"""
+        old_password = "monkey"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        next_link = "http://example.com"
+        self._request_token(email, client_secret, "127.0.0.1", next_link)
+
+        self.assertEqual(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link, next_link)
+
     def _request_token(
         self,
         email: str,
         client_secret: str,
         ip: str = "127.0.0.1",
+        next_link: Optional[str] = None,
     ) -> str:
+        body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+        if next_link is not None:
+            body["next_link"] = next_link
         channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
-            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            body,
             client_ip=ip,
         )
 
@@ -350,7 +383,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         return channel.json_body["sid"]
 
-    def _validate_token(self, link: str) -> None:
+    def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
         # Remove the host
         path = link.replace("https://example.com", "")
 
@@ -378,7 +411,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
             shorthand=False,
             content_is_form=True,
         )
-        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+        self.assertEqual(
+            HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
+            channel.code,
+            channel.result,
+        )
 
     def _get_link_from_email(self) -> str:
         assert self.email_attempts, "No emails have been sent"