diff --git a/tests/server.py b/tests/server.py
index 7bee58dff1..7919a1f124 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -21,6 +21,12 @@ from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+class TimedOutException(Exception):
+ """
+ A web query timed out.
+ """
+
+
@attr.s
class FakeChannel(object):
"""
@@ -28,6 +34,7 @@ class FakeChannel(object):
wire).
"""
+ _reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict))
_producer = None
@@ -50,6 +57,8 @@ class FakeChannel(object):
self.result["headers"] = headers
def write(self, content):
+ assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+
if "body" not in self.result:
self.result["body"] = b""
@@ -57,6 +66,15 @@ class FakeChannel(object):
def registerProducer(self, producer, streaming):
self._producer = producer
+ self.producerStreaming = streaming
+
+ def _produce():
+ if self._producer:
+ self._producer.resumeProducing()
+ self._reactor.callLater(0.1, _produce)
+
+ if not streaming:
+ self._reactor.callLater(0.0, _produce)
def unregisterProducer(self):
if self._producer is None:
@@ -98,10 +116,30 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
+def make_request(
+ reactor,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
+):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
+
+ Args:
+ method (bytes/unicode): The HTTP request method ("verb").
+ path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
+ escaped UTF-8 & spaces and such).
+ content (bytes or dict): The body of the request. JSON-encoded, if
+ a dict.
+ shorthand: Whether to try and be helpful and prefix the given URL
+ with the usual REST API path, if it doesn't contain it.
+
+ Returns:
+ A synapse.http.site.SynapseRequest.
"""
if not isinstance(method, bytes):
method = method.encode('ascii')
@@ -109,8 +147,8 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
if not isinstance(path, bytes):
path = path.encode('ascii')
- # Decorate it to be the full path
- if not path.startswith(b"/_matrix"):
+ # Decorate it to be the full path, if we're using shorthand
+ if shorthand and not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
@@ -118,14 +156,16 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
content = content.encode('utf8')
site = FakeSite()
- channel = FakeChannel()
+ channel = FakeChannel(reactor)
req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
if access_token:
- req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+ req.requestHeaders.addRawHeader(
+ b"Authorization", b"Bearer " + access_token.encode('ascii')
+ )
if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
@@ -151,7 +191,7 @@ def wait_until_result(clock, request, timeout=100):
x += 1
if x > timeout:
- raise Exception("Timed out waiting for request to finish.")
+ raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
|