summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
authorDavid Baker <dave@matrix.org>2018-11-09 18:35:02 +0000
committerDavid Baker <dave@matrix.org>2018-11-09 18:35:02 +0000
commitbca3b91c2dfeb63b43c3bfbb6700a38d4903f1eb (patch)
treeec18566349f9e3cae83699cffd2a714ae9dce932 /tests/server.py
parentpep8 (diff)
parentMerge pull request #4168 from matrix-org/babolivier/federation-client-content... (diff)
downloadsynapse-bca3b91c2dfeb63b43c3bfbb6700a38d4903f1eb.tar.xz
Merge remote-tracking branch 'origin/develop' into dbkr/e2e_backup_versions_are_numbers
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py52
1 files changed, 46 insertions, 6 deletions
diff --git a/tests/server.py b/tests/server.py
index 7bee58dff1..7919a1f124 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -21,6 +21,12 @@ from synapse.util import Clock
 from tests.utils import setup_test_homeserver as _sth
 
 
+class TimedOutException(Exception):
+    """
+    A web query timed out.
+    """
+
+
 @attr.s
 class FakeChannel(object):
     """
@@ -28,6 +34,7 @@ class FakeChannel(object):
     wire).
     """
 
+    _reactor = attr.ib()
     result = attr.ib(default=attr.Factory(dict))
     _producer = None
 
@@ -50,6 +57,8 @@ class FakeChannel(object):
         self.result["headers"] = headers
 
     def write(self, content):
+        assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+
         if "body" not in self.result:
             self.result["body"] = b""
 
@@ -57,6 +66,15 @@ class FakeChannel(object):
 
     def registerProducer(self, producer, streaming):
         self._producer = producer
+        self.producerStreaming = streaming
+
+        def _produce():
+            if self._producer:
+                self._producer.resumeProducing()
+                self._reactor.callLater(0.1, _produce)
+
+        if not streaming:
+            self._reactor.callLater(0.0, _produce)
 
     def unregisterProducer(self):
         if self._producer is None:
@@ -98,10 +116,30 @@ class FakeSite:
         return FakeLogger()
 
 
-def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
+def make_request(
+    reactor,
+    method,
+    path,
+    content=b"",
+    access_token=None,
+    request=SynapseRequest,
+    shorthand=True,
+):
     """
     Make a web request using the given method and path, feed it the
     content, and return the Request and the Channel underneath.
+
+    Args:
+        method (bytes/unicode): The HTTP request method ("verb").
+        path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
+        escaped UTF-8 & spaces and such).
+        content (bytes or dict): The body of the request. JSON-encoded, if
+        a dict.
+        shorthand: Whether to try and be helpful and prefix the given URL
+        with the usual REST API path, if it doesn't contain it.
+
+    Returns:
+        A synapse.http.site.SynapseRequest.
     """
     if not isinstance(method, bytes):
         method = method.encode('ascii')
@@ -109,8 +147,8 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
     if not isinstance(path, bytes):
         path = path.encode('ascii')
 
-    # Decorate it to be the full path
-    if not path.startswith(b"/_matrix"):
+    # Decorate it to be the full path, if we're using shorthand
+    if shorthand and not path.startswith(b"/_matrix"):
         path = b"/_matrix/client/r0/" + path
         path = path.replace(b"//", b"/")
 
@@ -118,14 +156,16 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
         content = content.encode('utf8')
 
     site = FakeSite()
-    channel = FakeChannel()
+    channel = FakeChannel(reactor)
 
     req = request(site, channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
 
     if access_token:
-        req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+        req.requestHeaders.addRawHeader(
+            b"Authorization", b"Bearer " + access_token.encode('ascii')
+        )
 
     if content:
         req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
@@ -151,7 +191,7 @@ def wait_until_result(clock, request, timeout=100):
         x += 1
 
         if x > timeout:
-            raise Exception("Timed out waiting for request to finish.")
+            raise TimedOutException("Timed out waiting for request to finish.")
 
         clock.advance(0.1)