diff --git a/tests/server.py b/tests/server.py
index 7bee58dff1..fc1e76d146 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,4 +1,5 @@
import json
+import logging
from io import BytesIO
from six import text_type
@@ -7,19 +8,28 @@ import attr
from zope.interface import implementer
from twisted.internet import address, threads, udp
-from twisted.internet._resolver import HostResolution
-from twisted.internet.address import IPv4Address
-from twisted.internet.defer import Deferred
+from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver
+from twisted.internet.interfaces import IReactorPluggableNameResolver, IResolverSimple
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.http import unquote
+from twisted.web.http_headers import Headers
from synapse.http.site import SynapseRequest
from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
+logger = logging.getLogger(__name__)
+
+
+class TimedOutException(Exception):
+ """
+ A web query timed out.
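+
+    Raised when a request does not complete within the time allowed by the
+    fake clock (see wait_until_result below).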
+ """
+
@attr.s
class FakeChannel(object):
@@ -28,6 +38,7 @@ class FakeChannel(object):
wire).
"""
+ _reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict))
_producer = None
@@ -43,6 +54,15 @@ class FakeChannel(object):
raise Exception("No result yet.")
return int(self.result["code"])
+ @property
+ def headers(self):
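+        """The headers of the response, as a twisted.web.http_headers.Headers."""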
+ if not self.result:
+ raise Exception("No result yet.")
+ h = Headers()
+ for i in self.result["headers"]:
+ h.addRawHeader(*i)
+ return h
+
def writeHeaders(self, version, code, reason, headers):
self.result["version"] = version
self.result["code"] = code
@@ -50,6 +70,8 @@ class FakeChannel(object):
self.result["headers"] = headers
def write(self, content):
+ assert isinstance(content, bytes), "Should be bytes! " + repr(content)
+
if "body" not in self.result:
self.result["body"] = b""
@@ -57,6 +79,15 @@ class FakeChannel(object):
def registerProducer(self, producer, streaming):
self._producer = producer
+ self.producerStreaming = streaming
+
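+        # A pull (non-streaming) producer only produces data when prodded, so
+        # poll it via the fake reactor to keep the response body flowing.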
+ def _produce():
+ if self._producer:
+ self._producer.resumeProducing()
+ self._reactor.callLater(0.1, _produce)
+
+ if not streaming:
+ self._reactor.callLater(0.0, _produce)
def unregisterProducer(self):
if self._producer is None:
@@ -98,10 +129,30 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
+def make_request(
+ reactor,
+ method,
+ path,
+ content=b"",
+ access_token=None,
+ request=SynapseRequest,
+ shorthand=True,
+):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
+
+    Args:
+        reactor: The (fake) reactor the resulting FakeChannel should use.
+        method (bytes/unicode): The HTTP request method ("verb").
+        path (bytes/unicode): The HTTP path, suitably URL-encoded (i.e. with
+            UTF-8 characters and spaces already escaped).
+        content (bytes or dict): The body of the request. JSON-encoded, if
+            a dict.
+        access_token (unicode): If given, sent as a Bearer token in the
+            Authorization header.
+        request: The request class to instantiate (defaults to SynapseRequest).
+        shorthand (bool): Whether to prefix the path with the usual client API
+            path (/_matrix/client/r0/) if it does not already start with
+            /_matrix.
+
+    Returns:
+        A tuple of the created synapse.http.site.SynapseRequest and the
+        FakeChannel it writes its response to.
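+
+    Example (illustrative only; assumes a test reactor and an already
+    registered access token):
+
+        request, channel = make_request(
+            reactor, "GET", "/sync", access_token=some_token
+        )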
"""
if not isinstance(method, bytes):
method = method.encode('ascii')
@@ -109,23 +160,29 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
if not isinstance(path, bytes):
path = path.encode('ascii')
- # Decorate it to be the full path
- if not path.startswith(b"/_matrix"):
+ # Decorate it to be the full path, if we're using shorthand
+ if shorthand and not path.startswith(b"/_matrix"):
path = b"/_matrix/client/r0/" + path
path = path.replace(b"//", b"/")
+ if not path.startswith(b"/"):
+ path = b"/" + path
+
if isinstance(content, text_type):
content = content.encode('utf8')
site = FakeSite()
- channel = FakeChannel()
+ channel = FakeChannel(reactor)
req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
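+    # Request.requestReceived would normally populate postpath; since we drive
+    # the request by hand, fill it in with the URL-decoded path segments.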
+ req.postpath = list(map(unquote, path[1:].split(b'/')))
if access_token:
- req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+ req.requestHeaders.addRawHeader(
+ b"Authorization", b"Bearer " + access_token.encode('ascii')
+ )
if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
@@ -151,7 +208,7 @@ def wait_until_result(clock, request, timeout=100):
x += 1
if x > timeout:
- raise Exception("Timed out waiting for request to finish.")
+ raise TimedOutException("Timed out waiting for request to finish.")
clock.advance(0.1)
@@ -169,30 +226,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self._udp = []
- self.lookups = {}
-
- class Resolver(object):
- def resolveHostName(
- _self,
- resolutionReceiver,
- hostName,
- portNumber=0,
- addressTypes=None,
- transportSemantics='TCP',
- ):
-
- resolution = HostResolution(hostName)
- resolutionReceiver.resolutionBegan(resolution)
- if hostName not in self.lookups:
- raise DNSLookupError("OH NO")
-
- resolutionReceiver.addressResolved(
- IPv4Address('TCP', self.lookups[hostName], portNumber)
- )
- resolutionReceiver.resolutionComplete()
- return resolution
-
- self.nameResolver = Resolver()
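+        # Tests populate self.lookups with hostname -> IP address entries to
+        # control what the fake resolver below returns.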
+ lookups = self.lookups = {}
+
+ @implementer(IResolverSimple)
+ class FakeResolver(object):
+ def getHostByName(self, name, timeout=None):
+ if name not in lookups:
+ return fail(DNSLookupError("OH NO: unknown %s" % (name, )))
+ return succeed(lookups[name])
+
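+        # SimpleResolverComplexifier wraps the simple resolver above so it can
+        # act as the reactor's full IHostnameResolver implementation.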
+ self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super(ThreadedMemoryReactorClock, self).__init__()
def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
@@ -284,7 +327,7 @@ def get_clock():
return (clock, hs_clock)
-@attr.s
+@attr.s(cmp=False)
class FakeTransport(object):
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -311,7 +354,13 @@ class FakeTransport(object):
:type: twisted.internet.interfaces.IReactorTime
"""
+ _protocol = attr.ib(default=None)
+ """The Protocol which is producing data for this transport. Optional, but if set
+ will get called back for connectionLost() notifications etc.
+ """
+
disconnecting = False
+ disconnected = False
buffer = attr.ib(default=b'')
producer = attr.ib(default=None)
@@ -321,15 +370,29 @@ class FakeTransport(object):
def getHost(self):
return None
- def loseConnection(self):
- self.disconnecting = True
+ def loseConnection(self, reason=None):
+ if not self.disconnecting:
+ logger.info("FakeTransport: loseConnection(%s)", reason)
+ self.disconnecting = True
+ if self._protocol:
+ self._protocol.connectionLost(reason)
+ self.disconnected = True
def abortConnection(self):
- self.disconnecting = True
+ logger.info("FakeTransport: abortConnection()")
+ self.loseConnection()
def pauseProducing(self):
+ if not self.producer:
+ return
+
self.producer.pauseProducing()
+ def resumeProducing(self):
+ if not self.producer:
+ return
+ self.producer.resumeProducing()
+
def unregisterProducer(self):
if not self.producer:
return
@@ -351,14 +414,29 @@ class FakeTransport(object):
self.buffer = self.buffer + byt
def _write():
+ if not self.buffer:
+ # nothing to do. Don't write empty buffers: it upsets the
+ # TLSMemoryBIOProtocol
+ return
+
+ if self.disconnected:
+ return
+ logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
+
if getattr(self.other, "transport") is not None:
- self.other.dataReceived(self.buffer)
- self.buffer = b""
+ try:
+ self.other.dataReceived(self.buffer)
+ self.buffer = b""
+ except Exception as e:
+ logger.warning("Exception writing to protocol: %s", e)
return
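+            # the other end is not connected yet; retry the write shortly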
self._reactor.callLater(0.0, _write)
- _write()
+ # always actually do the write asynchronously. Some protocols (notably the
+ # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
+ # still doing a write. Doing a callLater here breaks the cycle.
+ self._reactor.callLater(0.0, _write)
def writeSequence(self, seq):
for x in seq: