diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..a51ad0c14e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,8 +1,11 @@
import json
import logging
+from collections import deque
from io import SEEK_END, BytesIO
+from typing import Callable, Iterable, Optional, Tuple, Union
import attr
+from typing_extensions import Deque
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -16,8 +19,8 @@ from twisted.internet.interfaces import (
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
-from twisted.web.http import unquote
from twisted.web.http_headers import Headers
+from twisted.web.resource import IResource
from twisted.web.server import Site
from synapse.http.site import SynapseRequest
@@ -43,7 +46,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
- result = attr.ib(default=attr.Factory(dict))
+ result = attr.ib(type=dict, default=attr.Factory(dict))
_producer = None
@property
@@ -114,6 +117,25 @@ class FakeChannel:
def transport(self):
return self
+ def await_result(self, timeout: int = 100) -> None:
+ """
+ Wait until the request is finished.
+ """
+ self._reactor.run()
+ x = 0
+
+ while not self.result.get("done"):
+ # If there's a producer, tell it to resume producing so we get content
+ if self._producer:
+ self._producer.resumeProducing()
+
+ x += 1
+
+ if x > timeout:
+ raise TimedOutException("Timed out waiting for request to finish.")
+
+ self._reactor.advance(0.1)
+
class FakeSite:
"""
@@ -125,9 +147,21 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
+ def __init__(self, resource: IResource):
+ """
+
+ Args:
+ resource: the resource to be used for rendering all requests
+ """
+ self._resource = resource
+
+ def getResourceFor(self, request):
+ return self._resource
+
def make_request(
reactor,
+ site: Site,
method,
path,
content=b"",
@@ -136,12 +170,19 @@ def make_request(
shorthand=True,
federation_auth_origin=None,
content_is_form=False,
+ await_result: bool = True,
+ custom_headers: Optional[
+ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+ ] = None,
):
"""
- Make a web request using the given method and path, feed it the
- content, and return the Request and the Channel underneath.
+ Make a web request using the given method, path and content, and render it
+
+ Returns the Request and the Channel underneath.
Args:
+ site: The twisted Site to use to render the request
+
method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such).
@@ -154,6 +195,12 @@ def make_request(
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ custom_headers: (name, value) pairs to add as request headers
+
+ await_result: whether to wait for the request to complete rendering. If true,
+ will pump the reactor until the the renderer tells the channel the request
+ is finished.
+
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
@@ -175,18 +222,17 @@ def make_request(
if not path.startswith(b"/"):
path = b"/" + path
+ if isinstance(content, dict):
+ content = json.dumps(content).encode("utf8")
if isinstance(content, str):
content = content.encode("utf8")
- site = FakeSite()
channel = FakeChannel(site, reactor)
req = request(channel)
- req.process = lambda: b""
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END)
- req.postpath = list(map(unquote, path[1:].split(b"/")))
if access_token:
req.requestHeaders.addRawHeader(
@@ -208,35 +254,16 @@ def make_request(
# Assume the body is JSON
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
- req.requestReceived(method, path, b"1.1")
+ if custom_headers:
+ for k, v in custom_headers:
+ req.requestHeaders.addRawHeader(k, v)
- return req, channel
-
-
-def wait_until_result(clock, request, timeout=100):
- """
- Wait until the request is finished.
- """
- clock.run()
- x = 0
-
- while not request.finished:
-
- # If there's a producer, tell it to resume producing so we get content
- if request._channel._producer:
- request._channel._producer.resumeProducing()
-
- x += 1
-
- if x > timeout:
- raise TimedOutException("Timed out waiting for request to finish.")
-
- clock.advance(0.1)
+ req.requestReceived(method, path, b"1.1")
+ if await_result:
+ channel.await_result()
-def render(request, resource, clock):
- request.render(resource)
- wait_until_result(clock, request)
+ return req, channel
@implementer(IReactorPluggableNameResolver)
@@ -251,6 +278,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
+ self._thread_callbacks = deque() # type: Deque[Callable[[], None]]()
@implementer(IResolverSimple)
class FakeResolver:
@@ -272,10 +300,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
Make the callback fire in the next reactor iteration.
"""
- d = Deferred()
- d.addCallback(lambda x: callback(*args, **kwargs))
- self.callLater(0, d.callback, True)
- return d
+ cb = lambda: callback(*args, **kwargs)
+ # it's not safe to call callLater() here, so we append the callback to a
+ # separate queue.
+ self._thread_callbacks.append(cb)
def getThreadPool(self):
return self.threadpool
@@ -303,6 +331,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
+ def advance(self, amount):
+ # first advance our reactor's time, and run any "callLater" callbacks that
+ # makes ready
+ super().advance(amount)
+
+ # now run any "callFromThread" callbacks
+ while True:
+ try:
+ callback = self._thread_callbacks.popleft()
+ except IndexError:
+ break
+ callback()
+
+ # check for more "callLater" callbacks added by the thread callback
+ # This isn't required in a regular reactor, but it ends up meaning that
+ # our database queries can complete in a single call to `advance` [1] which
+ # simplifies tests.
+ #
+ # [1]: we replace the threadpool backing the db connection pool with a
+ # mock ThreadPool which doesn't really use threads; but we still use
+ # reactor.callFromThread to feed results back from the db functions to the
+ # main thread.
+ super().advance(0)
+
class ThreadPool:
"""
@@ -339,8 +391,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
"""
server = _sth(cleanup_func, *args, **kwargs)
- database = server.config.database.get_single_database()
-
# Make the thread pool synchronous.
clock = server.get_clock()
@@ -354,7 +404,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runWithConnection,
func,
*args,
- **kwargs
+ **kwargs,
)
def runInteraction(interaction, *args, **kwargs):
@@ -364,7 +414,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runInteraction,
interaction,
*args,
- **kwargs
+ **kwargs,
)
pool.runWithConnection = runWithConnection
@@ -372,6 +422,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
+ # We've just changed the Databases to run DB transactions on the same
+ # thread, so we need to disable the dedicated thread behaviour.
+ server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
return server
@@ -541,12 +595,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol
reactor
factory: The connecting factory to build.
"""
- factory = reactor.tcpClients[client_id][2]
+ factory = reactor.tcpClients.pop(client_id)[2]
client = factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, reactor))
client.makeConnection(FakeTransport(server, reactor))
- reactor.tcpClients.pop(client_id)
-
return client, server
|