summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py55
1 files changed, 42 insertions, 13 deletions
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -47,13 +47,26 @@ class FakeChannel:
     site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
+    _ip = attr.ib(type=str, default="127.0.0.1")
     _producer = None
 
     @property
     def json_body(self):
-        if not self.result:
-            raise Exception("No result yet.")
-        return json.loads(self.result["body"].decode("utf8"))
+        return json.loads(self.text_body)
+
+    @property
+    def text_body(self) -> str:
+        """The body of the result, utf-8-decoded.
+
+        Raises an exception if the request has not yet completed.
+        """
+        if not self.is_finished:
+            raise Exception("Request not yet completed")
+        return self.result["body"].decode("utf8")
+
+    def is_finished(self) -> bool:
+        """check if the response has been completely received"""
+        return self.result.get("done", False)
 
     @property
     def code(self):
@@ -62,7 +75,7 @@ class FakeChannel:
         return int(self.result["code"])
 
     @property
-    def headers(self):
+    def headers(self) -> Headers:
         if not self.result:
             raise Exception("No result yet.")
         h = Headers()
@@ -108,7 +121,7 @@ class FakeChannel:
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address("TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
         return None
@@ -124,7 +137,7 @@ class FakeChannel:
         self._reactor.run()
         x = 0
 
-        while not self.result.get("done"):
+        while not self.is_finished():
             # If there's a producer, tell it to resume producing so we get content
             if self._producer:
                 self._producer.resumeProducing()
@@ -136,6 +149,16 @@ class FakeChannel:
 
             self._reactor.advance(0.1)
 
+    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+        """Process the contents of any Set-Cookie headers in the response
+
+        Any cookines found are added to the given dict
+        """
+        for h in self.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
 
 class FakeSite:
     """
@@ -174,11 +197,12 @@ def make_request(
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
     ] = None,
-):
+    client_ip: str = "127.0.0.1",
+) -> FakeChannel:
     """
     Make a web request using the given method, path and content, and render it
 
-    Returns the Request and the Channel underneath.
+    Returns the fake Channel object which records the response to the request.
 
     Args:
         site: The twisted Site to use to render the request
@@ -201,8 +225,11 @@ def make_request(
              will pump the reactor until the the renderer tells the channel the request
              is finished.
 
+        client_ip: The IP to use as the requesting IP. Useful for testing
+            ratelimiting.
+
     Returns:
-        Tuple[synapse.http.site.SynapseRequest, channel]
+        channel
     """
     if not isinstance(method, bytes):
         method = method.encode("ascii")
@@ -216,8 +243,9 @@ def make_request(
         and not path.startswith(b"/_matrix")
         and not path.startswith(b"/_synapse")
     ):
+        if path.startswith(b"/"):
+            path = path[1:]
         path = b"/_matrix/client/r0/" + path
-        path = path.replace(b"//", b"/")
 
     if not path.startswith(b"/"):
         path = b"/" + path
@@ -227,7 +255,7 @@ def make_request(
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    channel = FakeChannel(site, reactor)
+    channel = FakeChannel(site, reactor, ip=client_ip)
 
     req = request(channel)
     req.content = BytesIO(content)
@@ -258,12 +286,13 @@ def make_request(
         for k, v in custom_headers:
             req.requestHeaders.addRawHeader(k, v)
 
+    req.parseCookies()
     req.requestReceived(method, path, b"1.1")
 
     if await_result:
         channel.await_result()
 
-    return req, channel
+    return channel
 
 
 @implementer(IReactorPluggableNameResolver)