diff options
-rw-r--r-- | synapse/rest/client/v2_alpha/account.py | 6 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_account.py | 17 | ||||
-rw-r--r-- | tests/server.py | 15 | ||||
-rw-r--r-- | tests/test_utils/http.py | 37 |
4 files changed, 67 insertions, 8 deletions
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index a309cf532d..412a6eaec9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -19,6 +19,7 @@ import random from http import HTTPStatus from typing import TYPE_CHECKING +from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -240,7 +241,7 @@ class PasswordResetConfirmationSubmitTokenServlet(RestServlet): hs.config.email_password_reset_template_failure_html ) - async def on_POST(self, request): + async def on_POST(self, request: Request): if self._threepid_behaviour_email == ThreepidBehaviour.OFF: if self._local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -255,6 +256,9 @@ class PasswordResetConfirmationSubmitTokenServlet(RestServlet): "Password resets for this homeserver are handled by a separate program", ) + logger.info("ARGS: %s, CONTENT: %s, HEADERS: %s", request.args, request.content, + request.getAllHeaders()) + sid = parse_string(request, "sid", required=True) token = parse_string(request, "token", required=True) client_secret = parse_string(request, "client_secret", required=True) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index f336671e2c..0223152295 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -16,9 +16,11 @@ # limitations under the License. import json +from urllib.parse import urlencode import os import re from email.parser import Parser +from tests.test_utils.http import convert_request_args_to_form_data import pkg_resources @@ -70,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + @unittest.INFO def test_basic_password_reset(self): """Test basic password reset flow """ @@ -256,15 +259,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.result) # Replace the path with the confirmation path - path = re.sub( - "^/_matrix.*submit_token", - "/_matrix/client/unstable/password_reset/email/submit_token_confirm", - path, - ) + path = "/_matrix/client/unstable/password_reset/email/submit_token_confirm" # Confirm the password reset - request, channel = self.make_request("POST", path, shorthand=False) + request, channel = self.make_request( + "POST", + path, + content=urlencode(request.args).encode("utf8"), + shorthand=False, + ) self.render(request) + print(channel.json_body) self.assertEquals(200, channel.code, channel.result) def _get_link_from_email(self): diff --git a/tests/server.py b/tests/server.py index b6e0b14e78..b1e5fd84fe 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,6 +1,7 @@ import json import logging from io import BytesIO +from json.decoder import JSONDecodeError import attr from zope.interface import implementer @@ -195,7 +196,19 @@ def make_request( ) if content: - req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + content_is_json = True + try: + json.loads(content) + except JSONDecodeError: + content_is_json = False + + print("Content is json?", content_is_json, path) + if content_is_json: + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + else: + req.requestHeaders.addRawHeader( + b"Content-Type", b"application/x-www-form-urlencoded" + ) req.requestReceived(method, path, b"1.1") diff --git a/tests/test_utils/http.py b/tests/test_utils/http.py new file mode 100644 index 0000000000..c2808e1ff5 --- /dev/null +++ b/tests/test_utils/http.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.server import Request + + +def convert_request_args_to_form_data(request: Request) -> bytes: + """Converts query arguments from a request to formatted HTML form data + + Ref: https://developer.mozilla.org/en-US/docs/Learn/Forms/Sending_and_retrieving_form_data + + Args: + The request to pull arguments from + + Returns: + The HTML form body data representation of the request's arguments + """ + body = b"" + for key, value in request.args.items(): + arg = b"%s=%s&" % (key, value[0]) + body += arg + + # Remove the last '&' sign + return body[:-1] |