summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py160
1 files changed, 119 insertions, 41 deletions
diff --git a/tests/server.py b/tests/server.py
index 7bee58dff1..fc1e76d146 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from io import BytesIO
 
 from six import text_type
@@ -7,19 +8,28 @@ import attr
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
-from twisted.internet._resolver import HostResolution
-from twisted.internet.address import IPv4Address
-from twisted.internet.defer import Deferred
+from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.defer import Deferred, fail, succeed
 from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver
+from twisted.internet.interfaces import IReactorPluggableNameResolver, IResolverSimple
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.http import unquote
+from twisted.web.http_headers import Headers
 
 from synapse.http.site import SynapseRequest
 from synapse.util import Clock
 
 from tests.utils import setup_test_homeserver as _sth
 
+logger = logging.getLogger(__name__)
+
+
+class TimedOutException(Exception):
+    """
+    A web query timed out.
+    """
+
 
 @attr.s
 class FakeChannel(object):
@@ -28,6 +38,7 @@ class FakeChannel(object):
     wire).
     """
 
+    _reactor = attr.ib()
     result = attr.ib(default=attr.Factory(dict))
     _producer = None
 
@@ -43,6 +54,15 @@ class FakeChannel(object):
             raise Exception("No result yet.")
         return int(self.result["code"])
 
+    @property
+    def headers(self):
+        if not self.result:
+            raise Exception("No result yet.")
+        h = Headers()
+        for i in self.result["headers"]:
+            h.addRawHeader(*i)
+        return h
+
     def writeHeaders(self, version, code, reason, headers):
         self.result["version"] = version
         self.result["code"] = code
@@ -50,6 +70,8 @@ class FakeChannel(object):
         self.result["headers"] = headers
 
     def write(self, content):
+        assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+
         if "body" not in self.result:
             self.result["body"] = b""
 
@@ -57,6 +79,15 @@ class FakeChannel(object):
 
     def registerProducer(self, producer, streaming):
         self._producer = producer
+        self.producerStreaming = streaming
+
+        def _produce():
+            if self._producer:
+                self._producer.resumeProducing()
+                self._reactor.callLater(0.1, _produce)
+
+        if not streaming:
+            self._reactor.callLater(0.0, _produce)
 
     def unregisterProducer(self):
         if self._producer is None:
@@ -98,10 +129,30 @@ class FakeSite:
         return FakeLogger()
 
 
-def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
+def make_request(
+    reactor,
+    method,
+    path,
+    content=b"",
+    access_token=None,
+    request=SynapseRequest,
+    shorthand=True,
+):
     """
     Make a web request using the given method and path, feed it the
     content, and return the Request and the Channel underneath.
+
+    Args:
+        method (bytes/unicode): The HTTP request method ("verb").
+        path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
+        escaped UTF-8 & spaces and such).
+        content (bytes or dict): The body of the request. JSON-encoded, if
+        a dict.
+        shorthand: Whether to try and be helpful and prefix the given URL
+        with the usual REST API path, if it doesn't contain it.
+
+    Returns:
+        A synapse.http.site.SynapseRequest.
     """
     if not isinstance(method, bytes):
         method = method.encode('ascii')
@@ -109,23 +160,29 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
     if not isinstance(path, bytes):
         path = path.encode('ascii')
 
-    # Decorate it to be the full path
-    if not path.startswith(b"/_matrix"):
+    # Decorate it to be the full path, if we're using shorthand
+    if shorthand and not path.startswith(b"/_matrix"):
         path = b"/_matrix/client/r0/" + path
         path = path.replace(b"//", b"/")
 
+    if not path.startswith(b"/"):
+        path = b"/" + path
+
     if isinstance(content, text_type):
         content = content.encode('utf8')
 
     site = FakeSite()
-    channel = FakeChannel()
+    channel = FakeChannel(reactor)
 
     req = request(site, channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
+    req.postpath = list(map(unquote, path[1:].split(b'/')))
 
     if access_token:
-        req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+        req.requestHeaders.addRawHeader(
+            b"Authorization", b"Bearer " + access_token.encode('ascii')
+        )
 
     if content:
         req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
@@ -151,7 +208,7 @@ def wait_until_result(clock, request, timeout=100):
         x += 1
 
         if x > timeout:
-            raise Exception("Timed out waiting for request to finish.")
+            raise TimedOutException("Timed out waiting for request to finish.")
 
         clock.advance(0.1)
 
@@ -169,30 +226,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
     def __init__(self):
         self._udp = []
-        self.lookups = {}
-
-        class Resolver(object):
-            def resolveHostName(
-                _self,
-                resolutionReceiver,
-                hostName,
-                portNumber=0,
-                addressTypes=None,
-                transportSemantics='TCP',
-            ):
-
-                resolution = HostResolution(hostName)
-                resolutionReceiver.resolutionBegan(resolution)
-                if hostName not in self.lookups:
-                    raise DNSLookupError("OH NO")
-
-                resolutionReceiver.addressResolved(
-                    IPv4Address('TCP', self.lookups[hostName], portNumber)
-                )
-                resolutionReceiver.resolutionComplete()
-                return resolution
-
-        self.nameResolver = Resolver()
+        lookups = self.lookups = {}
+
+        @implementer(IResolverSimple)
+        class FakeResolver(object):
+            def getHostByName(self, name, timeout=None):
+                if name not in lookups:
+                    return fail(DNSLookupError("OH NO: unknown %s" % (name, )))
+                return succeed(lookups[name])
+
+        self.nameResolver = SimpleResolverComplexifier(FakeResolver())
         super(ThreadedMemoryReactorClock, self).__init__()
 
     def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
@@ -284,7 +327,7 @@ def get_clock():
     return (clock, hs_clock)
 
 
-@attr.s
+@attr.s(cmp=False)
 class FakeTransport(object):
     """
     A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -311,7 +354,13 @@ class FakeTransport(object):
     :type: twisted.internet.interfaces.IReactorTime
     """
 
+    _protocol = attr.ib(default=None)
+    """The Protocol which is producing data for this transport. Optional, but if set
+    will get called back for connectionLost() notifications etc.
+    """
+
     disconnecting = False
+    disconnected = False
     buffer = attr.ib(default=b'')
     producer = attr.ib(default=None)
 
@@ -321,15 +370,29 @@ class FakeTransport(object):
     def getHost(self):
         return None
 
-    def loseConnection(self):
-        self.disconnecting = True
+    def loseConnection(self, reason=None):
+        if not self.disconnecting:
+            logger.info("FakeTransport: loseConnection(%s)", reason)
+            self.disconnecting = True
+            if self._protocol:
+                self._protocol.connectionLost(reason)
+            self.disconnected = True
 
     def abortConnection(self):
-        self.disconnecting = True
+        logger.info("FakeTransport: abortConnection()")
+        self.loseConnection()
 
     def pauseProducing(self):
+        if not self.producer:
+            return
+
         self.producer.pauseProducing()
 
+    def resumeProducing(self):
+        if not self.producer:
+            return
+        self.producer.resumeProducing()
+
     def unregisterProducer(self):
         if not self.producer:
             return
@@ -351,14 +414,29 @@ class FakeTransport(object):
         self.buffer = self.buffer + byt
 
         def _write():
+            if not self.buffer:
+                # nothing to do. Don't write empty buffers: it upsets the
+                # TLSMemoryBIOProtocol
+                return
+
+            if self.disconnected:
+                return
+            logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
+
             if getattr(self.other, "transport") is not None:
-                self.other.dataReceived(self.buffer)
-                self.buffer = b""
+                try:
+                    self.other.dataReceived(self.buffer)
+                    self.buffer = b""
+                except Exception as e:
+                    logger.warning("Exception writing to protocol: %s", e)
                 return
 
             self._reactor.callLater(0.0, _write)
 
-        _write()
+        # always actually do the write asynchronously. Some protocols (notably the
+        # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
+        # still doing a write. Doing a callLater here breaks the cycle.
+        self._reactor.callLater(0.0, _write)
 
     def writeSequence(self, seq):
         for x in seq: