summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/tests/server.py b/tests/server.py
index 48e45c6c8b..b404ad4e2a 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from io import BytesIO
+from io import SEEK_END, BytesIO
 
 import attr
 from zope.interface import implementer
@@ -135,6 +135,7 @@ def make_request(
     request=SynapseRequest,
     shorthand=True,
     federation_auth_origin=None,
+    content_is_form=False,
 ):
     """
     Make a web request using the given method and path, feed it the
@@ -150,6 +151,8 @@ def make_request(
         with the usual REST API path, if it doesn't contain it.
         federation_auth_origin (bytes|None): if set to not-None, we will add a fake
             Authorization header pretenting to be the given server name.
+        content_is_form: Whether the content is URL encoded form data. Adds the
+            'Content-Type': 'application/x-www-form-urlencoded' header.
 
     Returns:
         Tuple[synapse.http.site.SynapseRequest, channel]
@@ -181,6 +184,8 @@ def make_request(
     req = request(channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
+    # Twisted expects to be at the end of the content when parsing the request.
+    req.content.seek(SEEK_END)
     req.postpath = list(map(unquote, path[1:].split(b"/")))
 
     if access_token:
@@ -195,7 +200,13 @@ def make_request(
         )
 
     if content:
-        req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+        if content_is_form:
+            req.requestHeaders.addRawHeader(
+                b"Content-Type", b"application/x-www-form-urlencoded"
+            )
+        else:
+            # Assume the body is JSON
+            req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
     req.requestReceived(method, path, b"1.1")
 
@@ -249,7 +260,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
                 return succeed(lookups[name])
 
         self.nameResolver = SimpleResolverComplexifier(FakeResolver())
-        super(ThreadedMemoryReactorClock, self).__init__()
+        super().__init__()
 
     def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
         p = udp.Port(port, protocol, interface, maxPacketSize, self)