summary refs log tree commit diff
path: root/tests/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/server.py')
-rw-r--r--tests/server.py137
1 files changed, 131 insertions, 6 deletions
diff --git a/tests/server.py b/tests/server.py
index 615bba1b59..819c854448 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -4,9 +4,14 @@ from io import BytesIO
 from six import text_type
 
 import attr
+from zope.interface import implementer
 
-from twisted.internet import address, threads
+from twisted.internet import address, threads, udp
+from twisted.internet._resolver import HostResolution
+from twisted.internet.address import IPv4Address
 from twisted.internet.defer import Deferred
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactorClock
 
@@ -65,7 +70,7 @@ class FakeChannel(object):
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", "127.0.0.1", 3423)
 
     def getHost(self):
         return None
@@ -93,7 +98,7 @@ class FakeSite:
         return FakeLogger()
 
 
-def make_request(method, path, content=b"", access_token=None):
+def make_request(method, path, content=b"", access_token=None, request=SynapseRequest):
     """
     Make a web request using the given method and path, feed it the
     content, and return the Request and the Channel underneath.
@@ -115,14 +120,18 @@ def make_request(method, path, content=b"", access_token=None):
     site = FakeSite()
     channel = FakeChannel()
 
-    req = SynapseRequest(site, channel)
+    req = request(site, channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
 
     if access_token:
-        req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
+        req.requestHeaders.addRawHeader(
+            b"Authorization", b"Bearer " + access_token.encode('ascii')
+        )
+
+    if content:
+        req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
 
-    req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
     req.requestReceived(method, path, b"1.1")
 
     return req, channel
@@ -154,11 +163,46 @@ def render(request, resource, clock):
     wait_until_result(clock, request)
 
 
+@implementer(IReactorPluggableNameResolver)
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """
 
+    def __init__(self):
+        self._udp = []
+        self.lookups = {}
+
+        class Resolver(object):
+            def resolveHostName(
+                _self,
+                resolutionReceiver,
+                hostName,
+                portNumber=0,
+                addressTypes=None,
+                transportSemantics='TCP',
+            ):
+
+                resolution = HostResolution(hostName)
+                resolutionReceiver.resolutionBegan(resolution)
+                if hostName not in self.lookups:
+                    raise DNSLookupError("OH NO")
+
+                resolutionReceiver.addressResolved(
+                    IPv4Address('TCP', self.lookups[hostName], portNumber)
+                )
+                resolutionReceiver.resolutionComplete()
+                return resolution
+
+        self.nameResolver = Resolver()
+        super(ThreadedMemoryReactorClock, self).__init__()
+
+    def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
+        p = udp.Port(port, protocol, interface, maxPacketSize, self)
+        p.startListening()
+        self._udp.append(p)
+        return p
+
     def callFromThread(self, callback, *args, **kwargs):
         """
         Make the callback fire in the next reactor iteration.
@@ -240,3 +284,84 @@ def get_clock():
     clock = ThreadedMemoryReactorClock()
     hs_clock = Clock(clock)
     return (clock, hs_clock)
+
+
+@attr.s
+class FakeTransport(object):
+    """
+    A twisted.internet.interfaces.ITransport implementation which sends all its data
+    straight into an IProtocol object: it exists to connect two IProtocols together.
+
+    To use it, instantiate it with the receiving IProtocol, and then pass it to the
+    sending IProtocol's makeConnection method:
+
+        server = HTTPChannel()
+        client.makeConnection(FakeTransport(server, self.reactor))
+
+    If you want bidirectional communication, you'll need two instances.
+    """
+
+    other = attr.ib()
+    """The Protocol object which will receive any data written to this transport.
+
+    :type: twisted.internet.interfaces.IProtocol
+    """
+
+    _reactor = attr.ib()
+    """Test reactor
+
+    :type: twisted.internet.interfaces.IReactorTime
+    """
+
+    disconnecting = False
+    buffer = attr.ib(default=b'')
+    producer = attr.ib(default=None)
+
+    def getPeer(self):
+        return None
+
+    def getHost(self):
+        return None
+
+    def loseConnection(self):
+        self.disconnecting = True
+
+    def abortConnection(self):
+        self.disconnecting = True
+
+    def pauseProducing(self):
+        self.producer.pauseProducing()
+
+    def unregisterProducer(self):
+        if not self.producer:
+            return
+
+        self.producer = None
+
+    def registerProducer(self, producer, streaming):
+        self.producer = producer
+        self.producerStreaming = streaming
+
+        def _produce():
+            d = self.producer.resumeProducing()
+            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
+
+        if not streaming:
+            self._reactor.callLater(0.0, _produce)
+
+    def write(self, byt):
+        self.buffer = self.buffer + byt
+
+        def _write():
+            if getattr(self.other, "transport") is not None:
+                self.other.dataReceived(self.buffer)
+                self.buffer = b""
+                return
+
+            self._reactor.callLater(0.0, _write)
+
+        _write()
+
+    def writeSequence(self, seq):
+        for x in seq:
+            self.write(x)