summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/client/v2_alpha/account.py6
-rw-r--r--tests/rest/client/v2_alpha/test_account.py17
-rw-r--r--tests/server.py15
-rw-r--r--tests/test_utils/http.py37
4 files changed, 67 insertions, 8 deletions
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index a309cf532d..412a6eaec9 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -19,6 +19,7 @@ import random
 from http import HTTPStatus
 from typing import TYPE_CHECKING
 
+from twisted.web.server import Request
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -240,7 +241,7 @@ class PasswordResetConfirmationSubmitTokenServlet(RestServlet):
                 hs.config.email_password_reset_template_failure_html
             )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: Request):
         if self._threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self._local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -255,6 +256,9 @@ class PasswordResetConfirmationSubmitTokenServlet(RestServlet):
                 "Password resets for this homeserver are handled by a separate program",
             )
 
+        logger.info("ARGS: %s, CONTENT: %s, HEADERS: %s", request.args, request.content,
+                    request.getAllHeaders())
+
         sid = parse_string(request, "sid", required=True)
         token = parse_string(request, "token", required=True)
         client_secret = parse_string(request, "client_secret", required=True)
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index f336671e2c..0223152295 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -16,9 +16,11 @@
 # limitations under the License.
 
 import json
+from urllib.parse import urlencode
 import os
 import re
 from email.parser import Parser
+from tests.test_utils.http import convert_request_args_to_form_data
 
 import pkg_resources
 
@@ -70,6 +72,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
+    @unittest.INFO
     def test_basic_password_reset(self):
         """Test basic password reset flow
         """
@@ -256,15 +259,17 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.result)
 
         # Replace the path with the confirmation path
-        path = re.sub(
-            "^/_matrix.*submit_token",
-            "/_matrix/client/unstable/password_reset/email/submit_token_confirm",
-            path,
-        )
+        path = "/_matrix/client/unstable/password_reset/email/submit_token_confirm"
 
         # Confirm the password reset
-        request, channel = self.make_request("POST", path, shorthand=False)
+        request, channel = self.make_request(
+            "POST",
+            path,
+            content=urlencode(request.args).encode("utf8"),
+            shorthand=False,
+        )
         self.render(request)
+        print(channel.json_body)
         self.assertEquals(200, channel.code, channel.result)
 
     def _get_link_from_email(self):
diff --git a/tests/server.py b/tests/server.py
index b6e0b14e78..b1e5fd84fe 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,7 @@
 import json
 import logging
 from io import BytesIO
+from json.decoder import JSONDecodeError
 
 import attr
 from zope.interface import implementer
@@ -195,7 +196,19 @@ def make_request(
         )
 
     if content:
-        req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+        content_is_json = True
+        try:
+            json.loads(content)
+        except JSONDecodeError:
+            content_is_json = False
+
+        print("Content is json?", content_is_json, path)
+        if content_is_json:
+            req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+        else:
+            req.requestHeaders.addRawHeader(
+                b"Content-Type", b"application/x-www-form-urlencoded"
+            )
 
     req.requestReceived(method, path, b"1.1")
 
diff --git a/tests/test_utils/http.py b/tests/test_utils/http.py
new file mode 100644
index 0000000000..c2808e1ff5
--- /dev/null
+++ b/tests/test_utils/http.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.server import Request
+
+
+def convert_request_args_to_form_data(request: Request) -> bytes:
+    """Converts query arguments from a request to formatted HTML form data
+
+    Ref: https://developer.mozilla.org/en-US/docs/Learn/Forms/Sending_and_retrieving_form_data
+
+    Args:
+        The request to pull arguments from
+
+    Returns:
+        The HTML form body data representation of the request's arguments
+    """
+    body = b""
+    for key, value in request.args.items():
+        arg = b"%s=%s&" % (key, value[0])
+        body += arg
+
+    # Remove the last '&' sign
+    return body[:-1]