diff options
Diffstat (limited to 'tests/rest/client')
-rw-r--r-- | tests/rest/client/v2_alpha/test_account.py | 103 |
1 files changed, 96 insertions, 7 deletions
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 152a5182fa..0a51aeff92 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -14,11 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import os import re from email.parser import Parser +from typing import Optional import pkg_resources @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest +from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): @@ -668,16 +669,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def _request_token(self, email, client_secret): + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + "POST", b"account/3pid/email/requestToken", body, ) self.render(request) - self.assertEquals(200, channel.code, channel.result) + self.assertEquals(expect_code, channel.code, channel.result) - return channel.json_body["sid"] + return channel.json_body.get("sid") def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", |