summary refs log tree commit diff
path: root/tests/rest/client/test_account.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_account.py')
-rw-r--r--tests/rest/client/test_account.py69
1 files changed, 56 insertions, 13 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index bd59bb50cf..f1e4bdea76 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -1,16 +1,22 @@
-# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 import os
 import re
 from email.parser import Parser
@@ -322,16 +328,49 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
+    def test_password_reset_redirection(self) -> None:
+        """Test basic password reset flow"""
+        old_password = "monkey"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        next_link = "http://example.com"
+        self._request_token(email, client_secret, "127.0.0.1", next_link)
+
+        self.assertEqual(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link, next_link)
+
     def _request_token(
         self,
         email: str,
         client_secret: str,
         ip: str = "127.0.0.1",
+        next_link: Optional[str] = None,
     ) -> str:
+        body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+        if next_link is not None:
+            body["next_link"] = next_link
         channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
-            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            body,
             client_ip=ip,
         )
 
@@ -344,7 +383,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         return channel.json_body["sid"]
 
-    def _validate_token(self, link: str) -> None:
+    def _validate_token(self, link: str, next_link: Optional[str] = None) -> None:
         # Remove the host
         path = link.replace("https://example.com", "")
 
@@ -372,7 +411,11 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
             shorthand=False,
             content_is_form=True,
         )
-        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+        self.assertEqual(
+            HTTPStatus.OK if next_link is None else HTTPStatus.FOUND,
+            channel.code,
+            channel.result,
+        )
 
     def _get_link_from_email(self) -> str:
         assert self.email_attempts, "No emails have been sent"