Diffstat (limited to 'tests/server.py')
-rw-r--r--  tests/server.py  140
1 file changed, 96 insertions(+), 44 deletions(-)
diff --git a/tests/server.py b/tests/server.py
index b404ad4e2a..a51ad0c14e 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,8 +1,11 @@
 import json
 import logging
+from collections import deque
 from io import SEEK_END, BytesIO
+from typing import Callable, Iterable, Optional, Tuple, Union
 
 import attr
+from typing_extensions import Deque
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
@@ -16,8 +19,8 @@ from twisted.internet.interfaces import (
 )
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
-from twisted.web.http import unquote
 from twisted.web.http_headers import Headers
+from twisted.web.resource import IResource
 from twisted.web.server import Site
 
 from synapse.http.site import SynapseRequest
@@ -43,7 +46,7 @@ class FakeChannel:
 
     site = attr.ib(type=Site)
     _reactor = attr.ib()
-    result = attr.ib(default=attr.Factory(dict))
+    result = attr.ib(type=dict, default=attr.Factory(dict))
     _producer = None
 
     @property
@@ -114,6 +117,25 @@ class FakeChannel:
     def transport(self):
         return self
 
+    def await_result(self, timeout: int = 100) -> None:
+        """
+        Wait until the request is finished.
+        """
+        self._reactor.run()
+        x = 0
+
+        while not self.result.get("done"):
+            # If there's a producer, tell it to resume producing so we get content
+            if self._producer:
+                self._producer.resumeProducing()
+
+            x += 1
+
+            if x > timeout:
+                raise TimedOutException("Timed out waiting for request to finish.")
+
+            self._reactor.advance(0.1)
+
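A note on usage: a test that needs to interact with a request while it is still in flight can skip the automatic pumping and call this method itself. A rough sketch, relying on the `site` and `await_result` changes introduced further down; the path is illustrative:

    request, channel = make_request(
        reactor, site, b"GET", b"/_matrix/client/r0/sync", await_result=False
    )
    # ... interact with the still-unfinished request here ...
    channel.await_result()  # pump the reactor until the channel reports "done"
    assert channel.result["done"]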
 
 class FakeSite:
     """
@@ -125,9 +147,21 @@ class FakeSite:
     site_tag = "test"
     access_logger = logging.getLogger("synapse.access.http.fake")
 
+    def __init__(self, resource: IResource):
+        """
+
+        Args:
+            resource: the resource to be used for rendering all requests
+        """
+        self._resource = resource
+
+    def getResourceFor(self, request):
+        return self._resource
+
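FakeSite now has to be told up front which resource it serves, and every request rendered through it is routed to that single resource. A minimal sketch of wiring one up with a throwaway Twisted resource (the `_EchoResource` name is illustrative, not part of this change):

    from twisted.web.resource import Resource

    class _EchoResource(Resource):
        isLeaf = True

        def render_GET(self, request):
            return b"ok"

    site = FakeSite(_EchoResource())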
 
 def make_request(
     reactor,
+    site: Site,
     method,
     path,
     content=b"",
@@ -136,12 +170,19 @@ def make_request(
     shorthand=True,
     federation_auth_origin=None,
     content_is_form=False,
+    await_result: bool = True,
+    custom_headers: Optional[
+        Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+    ] = None,
 ):
     """
-    Make a web request using the given method and path, feed it the
-    content, and return the Request and the Channel underneath.
+    Make a web request using the given method, path and content, and render it.
+
+    Returns the Request and the Channel underneath.
 
     Args:
+        site: The twisted Site to use to render the request
+
         method (bytes/unicode): The HTTP request method ("verb").
         path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
         escaped UTF-8 & spaces and such).
@@ -154,6 +195,12 @@ def make_request(
         content_is_form: Whether the content is URL encoded form data. Adds the
             'Content-Type': 'application/x-www-form-urlencoded' header.
 
+        custom_headers: (name, value) pairs to add as request headers
+
+        await_result: whether to wait for the request to complete rendering. If true,
+             will pump the reactor until the renderer tells the channel the request
+             is finished.
+
     Returns:
         Tuple[synapse.http.site.SynapseRequest, channel]
     """
@@ -175,18 +222,17 @@ def make_request(
     if not path.startswith(b"/"):
         path = b"/" + path
 
+    if isinstance(content, dict):
+        content = json.dumps(content).encode("utf8")
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    site = FakeSite()
     channel = FakeChannel(site, reactor)
 
     req = request(channel)
-    req.process = lambda: b""
     req.content = BytesIO(content)
     # Twisted expects to be at the end of the content when parsing the request.
     req.content.seek(SEEK_END)
-    req.postpath = list(map(unquote, path[1:].split(b"/")))
 
     if access_token:
         req.requestHeaders.addRawHeader(
@@ -208,35 +254,16 @@ def make_request(
             # Assume the body is JSON
             req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
-    req.requestReceived(method, path, b"1.1")
+    if custom_headers:
+        for k, v in custom_headers:
+            req.requestHeaders.addRawHeader(k, v)
 
-    return req, channel
-
-
-def wait_until_result(clock, request, timeout=100):
-    """
-    Wait until the request is finished.
-    """
-    clock.run()
-    x = 0
-
-    while not request.finished:
-
-        # If there's a producer, tell it to resume producing so we get content
-        if request._channel._producer:
-            request._channel._producer.resumeProducing()
-
-        x += 1
-
-        if x > timeout:
-            raise TimedOutException("Timed out waiting for request to finish.")
-
-        clock.advance(0.1)
+    req.requestReceived(method, path, b"1.1")
 
+    if await_result:
+        channel.await_result()
 
-def render(request, resource, clock):
-    request.render(resource)
-    wait_until_result(clock, request)
+    return req, channel
 
 
 @implementer(IReactorPluggableNameResolver)
@@ -251,6 +278,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self._tcp_callbacks = {}
         self._udp = []
         lookups = self.lookups = {}
+        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]
 
         @implementer(IResolverSimple)
         class FakeResolver:
@@ -272,10 +300,10 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         Make the callback fire in the next reactor iteration.
         """
-        d = Deferred()
-        d.addCallback(lambda x: callback(*args, **kwargs))
-        self.callLater(0, d.callback, True)
-        return d
+        cb = lambda: callback(*args, **kwargs)
+        # it's not safe to call callLater() here, so we append the callback to a
+        # separate queue.
+        self._thread_callbacks.append(cb)
 
     def getThreadPool(self):
         return self.threadpool
@@ -303,6 +331,30 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
+    def advance(self, amount):
+        # first advance our reactor's time, and run any "callLater" callbacks that
+        # this makes ready
+        super().advance(amount)
+
+        # now run any "callFromThread" callbacks
+        while True:
+            try:
+                callback = self._thread_callbacks.popleft()
+            except IndexError:
+                break
+            callback()
+
+            # check for more "callLater" callbacks added by the thread callback
+            # This isn't required in a regular reactor, but it ends up meaning that
+            # our database queries can complete in a single call to `advance` [1] which
+            # simplifies tests.
+            #
+            # [1]: we replace the threadpool backing the db connection pool with a
+            # mock ThreadPool which doesn't really use threads; but we still use
+            # reactor.callFromThread to feed results back from the db functions to the
+            # main thread.
+            super().advance(0)
+
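To make the interaction between callFromThread and advance concrete: callFromThread only queues the callback, and the queue is drained the next time the clock advances. A minimal sketch (not part of the change):

    reactor = ThreadedMemoryReactorClock()
    fired = []
    reactor.callFromThread(fired.append, "done")
    assert fired == []   # nothing runs at call time; the callback is only queued
    reactor.advance(0)   # runs any timed calls, then drains _thread_callbacks
    assert fired == ["done"]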
 
 class ThreadPool:
     """
@@ -339,8 +391,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     """
     server = _sth(cleanup_func, *args, **kwargs)
 
-    database = server.config.database.get_single_database()
-
     # Make the thread pool synchronous.
     clock = server.get_clock()
 
@@ -354,7 +404,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
                 pool._runWithConnection,
                 func,
                 *args,
-                **kwargs
+                **kwargs,
             )
 
         def runInteraction(interaction, *args, **kwargs):
@@ -364,7 +414,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
                 pool._runInteraction,
                 interaction,
                 *args,
-                **kwargs
+                **kwargs,
             )
 
         pool.runWithConnection = runWithConnection
@@ -372,6 +422,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
+    # We've just changed the Databases to run DB transactions on the same
+    # thread, so we need to disable the dedicated thread behaviour.
+    server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+
     return server
 
 
@@ -541,12 +595,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol
         reactor
         factory: The connecting factory to build.
     """
-    factory = reactor.tcpClients[client_id][2]
+    factory = reactor.tcpClients.pop(client_id)[2]
     client = factory.buildProtocol(None)
     server = AccumulatingProtocol()
     server.makeConnection(FakeTransport(client, reactor))
     client.makeConnection(FakeTransport(server, reactor))
 
-    reactor.tcpClients.pop(client_id)
-
     return client, server
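
From the caller's side nothing changes; the pending factory is simply popped as it is consumed. Roughly, assuming the code under test has already initiated an outbound connection on the memory reactor:

    # something has called reactor.connectTCP(...), leaving a pending client
    # factory in reactor.tcpClients at index 0
    client, server = connect_client(reactor, 0)
    # `client` is now wired to an AccumulatingProtocol via FakeTransports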