summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/v2_alpha/test_account.py175
1 files changed, 145 insertions, 30 deletions
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 3ab611f618..152a5182fa 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    def test_basic_password_reset_canonicalise_email(self):
+        """Test basic password reset flow
+        Request password reset with different spelling
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email_profile = "test@example.com"
+        email_passwort_reset = "TEST@EXAMPLE.COM"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email_profile,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        client_secret = "foobar"
+        session_id = self._request_token(email_passwort_reset, client_secret)
+
+        self.assertEquals(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link)
+
+        self._reset_password(new_password, session_id, client_secret)
+
+        # Assert we can log in with the new password
+        self.login("kermit", new_password)
+
+        # Assert we can't log in with the old password
+        self.attempt_wrong_password_login("kermit", old_password)
+
     def test_cant_reset_password_without_clicking_link(self):
         """Test that we do actually need to click the link in the email
         """
@@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.email = "test@example.com"
         self.url_3pid = b"account/3pid"
 
-    def test_add_email(self):
-        """Test adding an email to profile
-        """
-        client_secret = "foobar"
-        session_id = self._request_token(self.email, client_secret)
+    def test_add_valid_email(self):
+        self.get_success(self._add_email(self.email, self.email))
 
-        self.assertEquals(len(self.email_attempts), 1)
-        link = self._get_link_from_email()
+    def test_add_valid_email_second_time(self):
+        self.get_success(self._add_email(self.email, self.email))
+        self.get_success(
+            self._request_token_invalid_email(
+                self.email,
+                expected_errcode=Codes.THREEPID_IN_USE,
+                expected_error="Email is already in use",
+            )
+        )
 
-        self._validate_token(link)
+    def test_add_valid_email_second_time_canonicalise(self):
+        self.get_success(self._add_email(self.email, self.email))
+        self.get_success(
+            self._request_token_invalid_email(
+                "TEST@EXAMPLE.COM",
+                expected_errcode=Codes.THREEPID_IN_USE,
+                expected_error="Email is already in use",
+            )
+        )
 
-        request, channel = self.make_request(
-            "POST",
-            b"/_matrix/client/unstable/account/3pid/add",
-            {
-                "client_secret": client_secret,
-                "sid": session_id,
-                "auth": {
-                    "type": "m.login.password",
-                    "user": self.user_id,
-                    "password": "test",
-                },
-            },
-            access_token=self.user_id_tok,
+    def test_add_email_no_at(self):
+        self.get_success(
+            self._request_token_invalid_email(
+                "address-without-at.bar",
+                expected_errcode=Codes.UNKNOWN,
+                expected_error="Unable to parse email address",
+            )
         )
 
-        self.render(request)
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+    def test_add_email_two_at(self):
+        self.get_success(
+            self._request_token_invalid_email(
+                "foo@foo@test.bar",
+                expected_errcode=Codes.UNKNOWN,
+                expected_error="Unable to parse email address",
+            )
+        )
 
-        # Get user
-        request, channel = self.make_request(
-            "GET", self.url_3pid, access_token=self.user_id_tok,
+    def test_add_email_bad_format(self):
+        self.get_success(
+            self._request_token_invalid_email(
+                "user@bad.example.net@good.example.com",
+                expected_errcode=Codes.UNKNOWN,
+                expected_error="Unable to parse email address",
+            )
         )
-        self.render(request)
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+    def test_add_email_domain_to_lower(self):
+        self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
+
+    def test_add_email_domain_with_umlaut(self):
+        self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
+
+    def test_add_email_address_casefold(self):
+        self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
+
+    def test_address_trim(self):
+        self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
     def test_add_email_if_disabled(self):
         """Test adding email to profile when doing so is disallowed
@@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         return channel.json_body["sid"]
 
+    def _request_token_invalid_email(
+        self, email, expected_errcode, expected_error, client_secret="foobar",
+    ):
+        request, channel = self.make_request(
+            "POST",
+            b"account/3pid/email/requestToken",
+            {"client_secret": client_secret, "email": email, "send_attempt": 1},
+        )
+        self.render(request)
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(expected_errcode, channel.json_body["errcode"])
+        self.assertEqual(expected_error, channel.json_body["error"])
+
     def _validate_token(self, link):
         # Remove the host
         path = link.replace("https://example.com", "")
@@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         assert match, "Could not find link in email"
 
         return match.group(0)
+
+    def _add_email(self, request_email, expected_email):
+        """Test adding an email to profile
+        """
+        client_secret = "foobar"
+        session_id = self._request_token(request_email, client_secret)
+
+        self.assertEquals(len(self.email_attempts), 1)
+        link = self._get_link_from_email()
+
+        self._validate_token(link)
+
+        request, channel = self.make_request(
+            "POST",
+            b"/_matrix/client/unstable/account/3pid/add",
+            {
+                "client_secret": client_secret,
+                "sid": session_id,
+                "auth": {
+                    "type": "m.login.password",
+                    "user": self.user_id,
+                    "password": "test",
+                },
+            },
+            access_token=self.user_id_tok,
+        )
+
+        self.render(request)
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Get user
+        request, channel = self.make_request(
+            "GET", self.url_3pid, access_token=self.user_id_tok,
+        )
+        self.render(request)
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])