diff --git a/tests/server.py b/tests/server.py
index 7dbdb7f8ea..7bee58dff1 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -4,9 +4,14 @@ from io import BytesIO
from six import text_type
import attr
+from zope.interface import implementer
-from twisted.internet import address, threads
+from twisted.internet import address, threads, udp
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address
from twisted.internet.defer import Deferred
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock
@@ -65,7 +70,7 @@ class FakeChannel(object):
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
- return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
+ return address.IPv4Address("TCP", "127.0.0.1", 3423)
def getHost(self):
return None
@@ -93,7 +98,7 @@ class FakeSite:
return FakeLogger()
-def make_request(method, path, content=b"", access_token=None):
+def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
"""
Make a web request using the given method and path, feed it the
content, and return the Request and the Channel underneath.
@@ -115,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None):
site = FakeSite()
channel = FakeChannel()
- req = SynapseRequest(site, channel)
+ req = request(site, channel)
req.process = lambda: b""
req.content = BytesIO(content)
if access_token:
req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
- req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
+ if content:
+ req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
+
req.requestReceived(method, path, b"1.1")
return req, channel
@@ -154,11 +161,46 @@ def render(request, resource, clock):
wait_until_result(clock, request)
+@implementer(IReactorPluggableNameResolver)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
"""
+ def __init__(self):
+ self._udp = []
+ self.lookups = {}
+
+ class Resolver(object):
+ def resolveHostName(
+ _self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics='TCP',
+ ):
+
+ resolution = HostResolution(hostName)
+ resolutionReceiver.resolutionBegan(resolution)
+ if hostName not in self.lookups:
+ raise DNSLookupError("OH NO")
+
+ resolutionReceiver.addressResolved(
+ IPv4Address('TCP', self.lookups[hostName], portNumber)
+ )
+ resolutionReceiver.resolutionComplete()
+ return resolution
+
+ self.nameResolver = Resolver()
+ super(ThreadedMemoryReactorClock, self).__init__()
+
+ def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
+ p = udp.Port(port, protocol, interface, maxPacketSize, self)
+ p.startListening()
+ self._udp.append(p)
+ return p
+
def callFromThread(self, callback, *args, **kwargs):
"""
Make the callback fire in the next reactor iteration.
@@ -232,6 +274,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
+ pool.running = True
return d
@@ -239,3 +282,84 @@ def get_clock():
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock)
return (clock, hs_clock)
+
+
+@attr.s
+class FakeTransport(object):
+ """
+ A twisted.internet.interfaces.ITransport implementation which sends all its data
+ straight into an IProtocol object: it exists to connect two IProtocols together.
+
+ To use it, instantiate it with the receiving IProtocol, and then pass it to the
+ sending IProtocol's makeConnection method:
+
+ server = HTTPChannel()
+ client.makeConnection(FakeTransport(server, self.reactor))
+
+ If you want bidirectional communication, you'll need two instances.
+ """
+
+ other = attr.ib()
+ """The Protocol object which will receive any data written to this transport.
+
+ :type: twisted.internet.interfaces.IProtocol
+ """
+
+ _reactor = attr.ib()
+ """Test reactor
+
+ :type: twisted.internet.interfaces.IReactorTime
+ """
+
+ disconnecting = False
+ buffer = attr.ib(default=b'')
+ producer = attr.ib(default=None)
+
+ def getPeer(self):
+ return None
+
+ def getHost(self):
+ return None
+
+ def loseConnection(self):
+ self.disconnecting = True
+
+ def abortConnection(self):
+ self.disconnecting = True
+
+ def pauseProducing(self):
+ self.producer.pauseProducing()
+
+ def unregisterProducer(self):
+ if not self.producer:
+ return
+
+ self.producer = None
+
+ def registerProducer(self, producer, streaming):
+ self.producer = producer
+ self.producerStreaming = streaming
+
+ def _produce():
+ d = self.producer.resumeProducing()
+ d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
+
+ if not streaming:
+ self._reactor.callLater(0.0, _produce)
+
+ def write(self, byt):
+ self.buffer = self.buffer + byt
+
+ def _write():
+ if getattr(self.other, "transport") is not None:
+ self.other.dataReceived(self.buffer)
+ self.buffer = b""
+ return
+
+ self._reactor.callLater(0.0, _write)
+
+ _write()
+
+ def writeSequence(self, seq):
+ for x in seq:
+ self.write(x)
|