summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py90
1 files changed, 57 insertions, 33 deletions
diff --git a/tests/server.py b/tests/server.py

index 3dd2cfc072..a51ad0c14e 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -2,7 +2,7 @@ import json import logging from collections import deque from io import SEEK_END, BytesIO -from typing import Callable +from typing import Callable, Iterable, Optional, Tuple, Union import attr from typing_extensions import Deque @@ -19,8 +19,8 @@ from twisted.internet.interfaces import ( ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock -from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.resource import IResource from twisted.web.server import Site from synapse.http.site import SynapseRequest @@ -117,6 +117,25 @@ class FakeChannel: def transport(self): return self + def await_result(self, timeout: int = 100) -> None: + """ + Wait until the request is finished. + """ + self._reactor.run() + x = 0 + + while not self.result.get("done"): + # If there's a producer, tell it to resume producing so we get content + if self._producer: + self._producer.resumeProducing() + + x += 1 + + if x > timeout: + raise TimedOutException("Timed out waiting for request to finish.") + + self._reactor.advance(0.1) + class FakeSite: """ @@ -128,9 +147,21 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") + def __init__(self, resource: IResource): + """ + + Args: + resource: the resource to be used for rendering all requests + """ + self._resource = resource + + def getResourceFor(self, request): + return self._resource + def make_request( reactor, + site: Site, method, path, content=b"", @@ -139,12 +170,19 @@ def make_request( shorthand=True, federation_auth_origin=None, content_is_form=False, + await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ): """ - Make a web request using the given method and path, feed it the - content, and return the Request and the Channel underneath. + Make a web request using the given method, path and content, and render it + + Returns the Request and the Channel underneath. Args: + site: The twisted Site to use to render the request + method (bytes/unicode): The HTTP request method ("verb"). path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). @@ -157,6 +195,12 @@ def make_request( content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. + custom_headers: (name, value) pairs to add as request headers + + await_result: whether to wait for the request to complete rendering. If true, + will pump the reactor until the the renderer tells the channel the request + is finished. + Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ @@ -178,18 +222,17 @@ def make_request( if not path.startswith(b"/"): path = b"/" + path + if isinstance(content, dict): + content = json.dumps(content).encode("utf8") if isinstance(content, str): content = content.encode("utf8") - site = FakeSite() channel = FakeChannel(site, reactor) req = request(channel) - req.process = lambda: b"" req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(SEEK_END) - req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( @@ -211,35 +254,16 @@ def make_request( # Assume the body is JSON req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") - req.requestReceived(method, path, b"1.1") - - return req, channel - + if custom_headers: + for k, v in custom_headers: + req.requestHeaders.addRawHeader(k, v) -def wait_until_result(clock, request, timeout=100): - """ - Wait until the request is finished. - """ - clock.run() - x = 0 - - while not request.finished: - - # If there's a producer, tell it to resume producing so we get content - if request._channel._producer: - request._channel._producer.resumeProducing() - - x += 1 - - if x > timeout: - raise TimedOutException("Timed out waiting for request to finish.") - - clock.advance(0.1) + req.requestReceived(method, path, b"1.1") + if await_result: + channel.await_result() -def render(request, resource, clock): - request.render(resource) - wait_until_result(clock, request) + return req, channel @implementer(IReactorPluggableNameResolver)