diff options
Diffstat (limited to 'tests/server.py')
-rw-r--r-- | tests/server.py | 134 |
1 files changed, 129 insertions, 5 deletions
diff --git a/tests/server.py b/tests/server.py index 7dbdb7f8ea..7bee58dff1 100644 --- a/tests/server.py +++ b/tests/server.py @@ -4,9 +4,14 @@ from io import BytesIO from six import text_type import attr +from zope.interface import implementer -from twisted.internet import address, threads +from twisted.internet import address, threads, udp +from twisted.internet._resolver import HostResolution +from twisted.internet.address import IPv4Address from twisted.internet.defer import Deferred +from twisted.internet.error import DNSLookupError +from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.python.failure import Failure from twisted.test.proto_helpers import MemoryReactorClock @@ -65,7 +70,7 @@ class FakeChannel(object): def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address(b"TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", "127.0.0.1", 3423) def getHost(self): return None @@ -93,7 +98,7 @@ class FakeSite: return FakeLogger() -def make_request(method, path, content=b"", access_token=None): +def make_request(method, path, content=b"", access_token=None, request=SynapseRequest): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. @@ -115,14 +120,16 @@ def make_request(method, path, content=b"", access_token=None): site = FakeSite() channel = FakeChannel() - req = SynapseRequest(site, channel) + req = request(site, channel) req.process = lambda: b"" req.content = BytesIO(content) if access_token: req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) - req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1") + if content: + req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") + req.requestReceived(method, path, b"1.1") return req, channel @@ -154,11 +161,46 @@ def render(request, resource, clock): wait_until_result(clock, request) +@implementer(IReactorPluggableNameResolver) class ThreadedMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports callFromThread. """ + def __init__(self): + self._udp = [] + self.lookups = {} + + class Resolver(object): + def resolveHostName( + _self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics='TCP', + ): + + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + if hostName not in self.lookups: + raise DNSLookupError("OH NO") + + resolutionReceiver.addressResolved( + IPv4Address('TCP', self.lookups[hostName], portNumber) + ) + resolutionReceiver.resolutionComplete() + return resolution + + self.nameResolver = Resolver() + super(ThreadedMemoryReactorClock, self).__init__() + + def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): + p = udp.Port(port, protocol, interface, maxPacketSize, self) + p.startListening() + self._udp.append(p) + return p + def callFromThread(self, callback, *args, **kwargs): """ Make the callback fire in the next reactor iteration. @@ -232,6 +274,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): clock.threadpool = ThreadPool() pool.threadpool = ThreadPool() + pool.running = True return d @@ -239,3 +282,84 @@ def get_clock(): clock = ThreadedMemoryReactorClock() hs_clock = Clock(clock) return (clock, hs_clock) + + +@attr.s +class FakeTransport(object): + """ + A twisted.internet.interfaces.ITransport implementation which sends all its data + straight into an IProtocol object: it exists to connect two IProtocols together. + + To use it, instantiate it with the receiving IProtocol, and then pass it to the + sending IProtocol's makeConnection method: + + server = HTTPChannel() + client.makeConnection(FakeTransport(server, self.reactor)) + + If you want bidirectional communication, you'll need two instances. + """ + + other = attr.ib() + """The Protocol object which will receive any data written to this transport. + + :type: twisted.internet.interfaces.IProtocol + """ + + _reactor = attr.ib() + """Test reactor + + :type: twisted.internet.interfaces.IReactorTime + """ + + disconnecting = False + buffer = attr.ib(default=b'') + producer = attr.ib(default=None) + + def getPeer(self): + return None + + def getHost(self): + return None + + def loseConnection(self): + self.disconnecting = True + + def abortConnection(self): + self.disconnecting = True + + def pauseProducing(self): + self.producer.pauseProducing() + + def unregisterProducer(self): + if not self.producer: + return + + self.producer = None + + def registerProducer(self, producer, streaming): + self.producer = producer + self.producerStreaming = streaming + + def _produce(): + d = self.producer.resumeProducing() + d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) + + if not streaming: + self._reactor.callLater(0.0, _produce) + + def write(self, byt): + self.buffer = self.buffer + byt + + def _write(): + if getattr(self.other, "transport") is not None: + self.other.dataReceived(self.buffer) + self.buffer = b"" + return + + self._reactor.callLater(0.0, _write) + + _write() + + def writeSequence(self, seq): + for x in seq: + self.write(x) |